From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- .gitignore | 15 + LICENSES/Apache-2.0.txt | 177 +++ LICENSES/BSD-3-Clause.txt | 26 + LICENSES/CC-PDDC.txt | 28 + LICENSES/MIT.txt | 21 + MANIFEST.in | 5 + README.md | 312 ++++++ pyproject.toml | 78 ++ setup.cfg | 79 ++ setup.py | 7 + src/aiet/__init__.py | 7 + src/aiet/backend/__init__.py | 3 + src/aiet/backend/application.py | 187 ++++ src/aiet/backend/common.py | 532 +++++++++ src/aiet/backend/config.py | 107 ++ src/aiet/backend/controller.py | 134 +++ src/aiet/backend/execution.py | 859 +++++++++++++++ src/aiet/backend/output_parser.py | 176 +++ src/aiet/backend/protocol.py | 325 ++++++ src/aiet/backend/source.py | 209 ++++ src/aiet/backend/system.py | 289 +++++ src/aiet/backend/tool.py | 109 ++ src/aiet/cli/__init__.py | 28 + src/aiet/cli/application.py | 362 ++++++ src/aiet/cli/common.py | 173 +++ src/aiet/cli/completion.py | 72 ++ src/aiet/cli/system.py | 122 +++ src/aiet/cli/tool.py | 143 +++ src/aiet/main.py | 13 + src/aiet/resources/applications/.gitignore | 6 + src/aiet/resources/systems/.gitignore | 6 + src/aiet/resources/tools/vela/aiet-config.json | 73 ++ .../resources/tools/vela/aiet-config.json.license | 3 + src/aiet/resources/tools/vela/check_model.py | 75 ++ src/aiet/resources/tools/vela/run_vela.py | 65 ++ src/aiet/resources/tools/vela/vela.ini | 53 + src/aiet/utils/__init__.py | 3 + src/aiet/utils/fs.py | 116 ++ src/aiet/utils/helpers.py | 17 + src/aiet/utils/proc.py | 283 +++++ src/mlia/__init__.py | 22 + src/mlia/api.py | 162 +++ src/mlia/cli/__init__.py | 3 + src/mlia/cli/commands.py | 276 +++++ src/mlia/cli/common.py | 38 + src/mlia/cli/config.py | 64 ++ src/mlia/cli/helpers.py | 116 ++ src/mlia/cli/logging.py | 117 ++ src/mlia/cli/main.py | 280 +++++ src/mlia/cli/options.py | 280 +++++ src/mlia/core/__init__.py | 21 + src/mlia/core/_typing.py | 12 + src/mlia/core/advice_generation.py | 106 ++ src/mlia/core/advisor.py | 21 + src/mlia/core/common.py | 47 + src/mlia/core/context.py | 218 ++++ src/mlia/core/data_analysis.py | 70 ++ src/mlia/core/data_collection.py | 37 + src/mlia/core/errors.py | 18 + src/mlia/core/events.py | 455 ++++++++ src/mlia/core/helpers.py | 38 + src/mlia/core/mixins.py | 54 + src/mlia/core/performance.py | 47 + src/mlia/core/reporting.py | 762 +++++++++++++ src/mlia/core/workflow.py | 216 ++++ src/mlia/devices/__init__.py | 3 + src/mlia/devices/config.py | 11 + src/mlia/devices/ethosu/__init__.py | 3 + src/mlia/devices/ethosu/advice_generation.py | 209 ++++ src/mlia/devices/ethosu/advisor.py | 151 +++ src/mlia/devices/ethosu/config.py | 89 ++ src/mlia/devices/ethosu/data_analysis.py | 154 +++ src/mlia/devices/ethosu/data_collection.py | 188 ++++ src/mlia/devices/ethosu/events.py | 24 + src/mlia/devices/ethosu/handlers.py | 146 +++ src/mlia/devices/ethosu/operators.py | 14 + src/mlia/devices/ethosu/performance.py | 257 +++++ src/mlia/devices/ethosu/reporters.py | 398 +++++++ src/mlia/nn/__init__.py | 3 + src/mlia/nn/tensorflow/__init__.py | 3 + src/mlia/nn/tensorflow/config.py | 134 +++ src/mlia/nn/tensorflow/optimizations/__init__.py | 3 + src/mlia/nn/tensorflow/optimizations/clustering.py | 109 ++ src/mlia/nn/tensorflow/optimizations/common.py | 29 + src/mlia/nn/tensorflow/optimizations/pruning.py | 168 +++ src/mlia/nn/tensorflow/optimizations/select.py | 179 +++ src/mlia/nn/tensorflow/tflite_metrics.py | 296 +++++ src/mlia/nn/tensorflow/utils.py | 149 +++ .../resources/aiet/applications/APPLICATIONS.txt | 6 + .../aiet-config.json | 18 + .../aiet-config.json.license | 3 + .../ethos-u-inference_runner.axf | Bin 0 -> 426496 bytes .../ethos-u-inference_runner.axf.license | 31 + .../aiet-config.json | 15 + .../aiet-config.json.license | 3 + .../ethos-u-inference_runner.axf | Bin 0 -> 426544 bytes .../ethos-u-inference_runner.axf.license | 31 + .../aiet-config.json | 15 + .../aiet-config.json.license | 3 + .../ethos-u-inference_runner.axf | Bin 0 -> 2524028 bytes .../ethos-u-inference_runner.axf.license | 31 + .../aiet-config.json | 15 + .../aiet-config.json.license | 3 + .../ethos-u-inference_runner.axf | Bin 0 -> 426488 bytes .../ethos-u-inference_runner.axf.license | 31 + .../aiet-config.json | 15 + .../aiet-config.json.license | 3 + .../ethos-u-inference_runner.axf | Bin 0 -> 426536 bytes .../ethos-u-inference_runner.axf.license | 31 + src/mlia/resources/aiet/systems/SYSTEMS.txt | 10 + .../aiet/systems/corstone-300-vht/aiet-config.json | 80 ++ .../corstone-300-vht/aiet-config.json.license | 3 + .../aiet/systems/corstone-300/aiet-config.json | 80 ++ .../systems/corstone-300/aiet-config.json.license | 3 + .../aiet/systems/corstone-310-vht/aiet-config.json | 42 + .../corstone-310-vht/aiet-config.json.license | 3 + .../aiet/systems/corstone-310/aiet-config.json | 42 + .../systems/corstone-310/aiet-config.json.license | 3 + src/mlia/resources/profiles.json | 20 + src/mlia/resources/profiles.json.license | 3 + src/mlia/resources/vela/vela.ini | 75 ++ src/mlia/tools/__init__.py | 3 + src/mlia/tools/aiet_wrapper.py | 435 ++++++++ src/mlia/tools/metadata/__init__.py | 3 + src/mlia/tools/metadata/common.py | 290 +++++ src/mlia/tools/metadata/corstone.py | 402 +++++++ src/mlia/tools/vela_wrapper.py | 500 +++++++++ src/mlia/utils/__init__.py | 3 + src/mlia/utils/console.py | 97 ++ src/mlia/utils/download.py | 89 ++ src/mlia/utils/filesystem.py | 124 +++ src/mlia/utils/logging.py | 120 ++ src/mlia/utils/misc.py | 9 + src/mlia/utils/proc.py | 164 +++ src/mlia/utils/types.py | 37 + tests/__init__.py | 3 + tests/aiet/__init__.py | 3 + tests/aiet/conftest.py | 139 +++ tests/aiet/test_backend_application.py | 452 ++++++++ tests/aiet/test_backend_common.py | 486 +++++++++ tests/aiet/test_backend_controller.py | 160 +++ tests/aiet/test_backend_execution.py | 526 +++++++++ tests/aiet/test_backend_output_parser.py | 152 +++ tests/aiet/test_backend_protocol.py | 231 ++++ tests/aiet/test_backend_source.py | 199 ++++ tests/aiet/test_backend_system.py | 536 +++++++++ tests/aiet/test_backend_tool.py | 60 + tests/aiet/test_check_model.py | 162 +++ tests/aiet/test_cli.py | 37 + tests/aiet/test_cli_application.py | 1153 ++++++++++++++++++++ tests/aiet/test_cli_common.py | 37 + tests/aiet/test_cli_system.py | 240 ++++ tests/aiet/test_cli_tool.py | 333 ++++++ tests/aiet/test_main.py | 16 + tests/aiet/test_resources/application_config.json | 96 ++ .../test_resources/application_config.json.license | 3 + .../applications/application1/aiet-config.json | 30 + .../application1/aiet-config.json.license | 3 + .../applications/application2/aiet-config.json | 30 + .../application2/aiet-config.json.license | 3 + .../applications/application3/readme.txt | 4 + .../applications/application4/aiet-config.json | 35 + .../application4/aiet-config.json.license | 3 + .../applications/application4/hello_app.txt | 4 + .../applications/application5/aiet-config.json | 160 +++ .../application5/aiet-config.json.license | 3 + tests/aiet/test_resources/applications/readme.txt | 4 + tests/aiet/test_resources/hello_world.json | 54 + tests/aiet/test_resources/hello_world.json.license | 3 + tests/aiet/test_resources/scripts/test_backend_run | 8 + .../scripts/test_backend_run_script.sh | 8 + .../systems/system1/aiet-config.json | 35 + .../systems/system1/aiet-config.json.license | 3 + .../systems/system1/system_artifact/dummy.txt | 2 + .../systems/system2/aiet-config.json | 32 + .../systems/system2/aiet-config.json.license | 3 + .../aiet/test_resources/systems/system3/readme.txt | 4 + .../systems/system4/aiet-config.json | 19 + .../systems/system4/aiet-config.json.license | 3 + .../test_resources/tools/tool1/aiet-config.json | 30 + .../tools/tool1/aiet-config.json.license | 3 + .../test_resources/tools/tool2/aiet-config.json | 26 + .../tools/tool2/aiet-config.json.license | 3 + .../application_with_empty_config/aiet-config.json | 1 + .../aiet-config.json.license | 3 + .../application_with_valid_config/aiet-config.json | 35 + .../aiet-config.json.license | 3 + .../aiet-config.json | 2 + .../aiet-config.json.license | 3 + .../aiet-config.json | 30 + .../aiet-config.json.license | 3 + .../aiet-config.json | 35 + .../aiet-config.json.license | 3 + .../system_with_empty_config/aiet-config.json | 1 + .../aiet-config.json.license | 3 + .../system_with_valid_config/aiet-config.json | 16 + .../aiet-config.json.license | 3 + tests/aiet/test_run_vela_script.py | 152 +++ tests/aiet/test_utils_fs.py | 168 +++ tests/aiet/test_utils_helpers.py | 27 + tests/aiet/test_utils_proc.py | 272 +++++ tests/conftest.py | 95 ++ tests/mlia/__init__.py | 3 + tests/mlia/conftest.py | 20 + tests/mlia/test_api.py | 96 ++ tests/mlia/test_cli_commands.py | 204 ++++ tests/mlia/test_cli_config.py | 49 + tests/mlia/test_cli_helpers.py | 165 +++ tests/mlia/test_cli_logging.py | 104 ++ tests/mlia/test_cli_main.py | 357 ++++++ tests/mlia/test_cli_options.py | 186 ++++ tests/mlia/test_core_advice_generation.py | 71 ++ tests/mlia/test_core_advisor.py | 40 + tests/mlia/test_core_context.py | 62 ++ tests/mlia/test_core_data_analysis.py | 31 + tests/mlia/test_core_events.py | 155 +++ tests/mlia/test_core_helpers.py | 17 + tests/mlia/test_core_mixins.py | 99 ++ tests/mlia/test_core_performance.py | 29 + tests/mlia/test_core_reporting.py | 413 +++++++ tests/mlia/test_core_workflow.py | 164 +++ .../mlia/test_devices_ethosu_advice_generation.py | 483 ++++++++ tests/mlia/test_devices_ethosu_advisor.py | 9 + tests/mlia/test_devices_ethosu_config.py | 124 +++ tests/mlia/test_devices_ethosu_data_analysis.py | 147 +++ tests/mlia/test_devices_ethosu_data_collection.py | 151 +++ tests/mlia/test_devices_ethosu_performance.py | 28 + tests/mlia/test_devices_ethosu_reporters.py | 434 ++++++++ tests/mlia/test_nn_tensorflow_config.py | 72 ++ .../test_nn_tensorflow_optimizations_clustering.py | 131 +++ .../test_nn_tensorflow_optimizations_pruning.py | 117 ++ .../test_nn_tensorflow_optimizations_select.py | 240 ++++ tests/mlia/test_nn_tensorflow_tflite_metrics.py | 137 +++ tests/mlia/test_nn_tensorflow_utils.py | 81 ++ tests/mlia/test_resources/vela/sample_vela.ini | 47 + tests/mlia/test_tools_aiet_wrapper.py | 760 +++++++++++++ tests/mlia/test_tools_metadata_common.py | 196 ++++ tests/mlia/test_tools_metadata_corstone.py | 419 +++++++ tests/mlia/test_tools_vela_wrapper.py | 285 +++++ tests/mlia/test_utils_console.py | 100 ++ tests/mlia/test_utils_download.py | 147 +++ tests/mlia/test_utils_filesystem.py | 166 +++ tests/mlia/test_utils_logging.py | 63 ++ tests/mlia/test_utils_misc.py | 25 + tests/mlia/test_utils_proc.py | 149 +++ tests/mlia/test_utils_types.py | 77 ++ tests/mlia/utils/__init__.py | 3 + tests/mlia/utils/common.py | 32 + tests/mlia/utils/logging.py | 13 + 249 files changed, 27687 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSES/Apache-2.0.txt create mode 100644 LICENSES/BSD-3-Clause.txt create mode 100644 LICENSES/CC-PDDC.txt create mode 100644 LICENSES/MIT.txt create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/aiet/__init__.py create mode 100644 src/aiet/backend/__init__.py create mode 100644 src/aiet/backend/application.py create mode 100644 src/aiet/backend/common.py create mode 100644 src/aiet/backend/config.py create mode 100644 src/aiet/backend/controller.py create mode 100644 src/aiet/backend/execution.py create mode 100644 src/aiet/backend/output_parser.py create mode 100644 src/aiet/backend/protocol.py create mode 100644 src/aiet/backend/source.py create mode 100644 src/aiet/backend/system.py create mode 100644 src/aiet/backend/tool.py create mode 100644 src/aiet/cli/__init__.py create mode 100644 src/aiet/cli/application.py create mode 100644 src/aiet/cli/common.py create mode 100644 src/aiet/cli/completion.py create mode 100644 src/aiet/cli/system.py create mode 100644 src/aiet/cli/tool.py create mode 100644 src/aiet/main.py create mode 100644 src/aiet/resources/applications/.gitignore create mode 100644 src/aiet/resources/systems/.gitignore create mode 100644 src/aiet/resources/tools/vela/aiet-config.json create mode 100644 src/aiet/resources/tools/vela/aiet-config.json.license create mode 100644 src/aiet/resources/tools/vela/check_model.py create mode 100644 src/aiet/resources/tools/vela/run_vela.py create mode 100644 src/aiet/resources/tools/vela/vela.ini create mode 100644 src/aiet/utils/__init__.py create mode 100644 src/aiet/utils/fs.py create mode 100644 src/aiet/utils/helpers.py create mode 100644 src/aiet/utils/proc.py create mode 100644 src/mlia/__init__.py create mode 100644 src/mlia/api.py create mode 100644 src/mlia/cli/__init__.py create mode 100644 src/mlia/cli/commands.py create mode 100644 src/mlia/cli/common.py create mode 100644 src/mlia/cli/config.py create mode 100644 src/mlia/cli/helpers.py create mode 100644 src/mlia/cli/logging.py create mode 100644 src/mlia/cli/main.py create mode 100644 src/mlia/cli/options.py create mode 100644 src/mlia/core/__init__.py create mode 100644 src/mlia/core/_typing.py create mode 100644 src/mlia/core/advice_generation.py create mode 100644 src/mlia/core/advisor.py create mode 100644 src/mlia/core/common.py create mode 100644 src/mlia/core/context.py create mode 100644 src/mlia/core/data_analysis.py create mode 100644 src/mlia/core/data_collection.py create mode 100644 src/mlia/core/errors.py create mode 100644 src/mlia/core/events.py create mode 100644 src/mlia/core/helpers.py create mode 100644 src/mlia/core/mixins.py create mode 100644 src/mlia/core/performance.py create mode 100644 src/mlia/core/reporting.py create mode 100644 src/mlia/core/workflow.py create mode 100644 src/mlia/devices/__init__.py create mode 100644 src/mlia/devices/config.py create mode 100644 src/mlia/devices/ethosu/__init__.py create mode 100644 src/mlia/devices/ethosu/advice_generation.py create mode 100644 src/mlia/devices/ethosu/advisor.py create mode 100644 src/mlia/devices/ethosu/config.py create mode 100644 src/mlia/devices/ethosu/data_analysis.py create mode 100644 src/mlia/devices/ethosu/data_collection.py create mode 100644 src/mlia/devices/ethosu/events.py create mode 100644 src/mlia/devices/ethosu/handlers.py create mode 100644 src/mlia/devices/ethosu/operators.py create mode 100644 src/mlia/devices/ethosu/performance.py create mode 100644 src/mlia/devices/ethosu/reporters.py create mode 100644 src/mlia/nn/__init__.py create mode 100644 src/mlia/nn/tensorflow/__init__.py create mode 100644 src/mlia/nn/tensorflow/config.py create mode 100644 src/mlia/nn/tensorflow/optimizations/__init__.py create mode 100644 src/mlia/nn/tensorflow/optimizations/clustering.py create mode 100644 src/mlia/nn/tensorflow/optimizations/common.py create mode 100644 src/mlia/nn/tensorflow/optimizations/pruning.py create mode 100644 src/mlia/nn/tensorflow/optimizations/select.py create mode 100644 src/mlia/nn/tensorflow/tflite_metrics.py create mode 100644 src/mlia/nn/tensorflow/utils.py create mode 100644 src/mlia/resources/aiet/applications/APPLICATIONS.txt create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf create mode 100644 src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license create mode 100644 src/mlia/resources/aiet/systems/SYSTEMS.txt create mode 100644 src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json create mode 100644 src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/systems/corstone-300/aiet-config.json create mode 100644 src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json create mode 100644 src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license create mode 100644 src/mlia/resources/aiet/systems/corstone-310/aiet-config.json create mode 100644 src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license create mode 100644 src/mlia/resources/profiles.json create mode 100644 src/mlia/resources/profiles.json.license create mode 100644 src/mlia/resources/vela/vela.ini create mode 100644 src/mlia/tools/__init__.py create mode 100644 src/mlia/tools/aiet_wrapper.py create mode 100644 src/mlia/tools/metadata/__init__.py create mode 100644 src/mlia/tools/metadata/common.py create mode 100644 src/mlia/tools/metadata/corstone.py create mode 100644 src/mlia/tools/vela_wrapper.py create mode 100644 src/mlia/utils/__init__.py create mode 100644 src/mlia/utils/console.py create mode 100644 src/mlia/utils/download.py create mode 100644 src/mlia/utils/filesystem.py create mode 100644 src/mlia/utils/logging.py create mode 100644 src/mlia/utils/misc.py create mode 100644 src/mlia/utils/proc.py create mode 100644 src/mlia/utils/types.py create mode 100644 tests/__init__.py create mode 100644 tests/aiet/__init__.py create mode 100644 tests/aiet/conftest.py create mode 100644 tests/aiet/test_backend_application.py create mode 100644 tests/aiet/test_backend_common.py create mode 100644 tests/aiet/test_backend_controller.py create mode 100644 tests/aiet/test_backend_execution.py create mode 100644 tests/aiet/test_backend_output_parser.py create mode 100644 tests/aiet/test_backend_protocol.py create mode 100644 tests/aiet/test_backend_source.py create mode 100644 tests/aiet/test_backend_system.py create mode 100644 tests/aiet/test_backend_tool.py create mode 100644 tests/aiet/test_check_model.py create mode 100644 tests/aiet/test_cli.py create mode 100644 tests/aiet/test_cli_application.py create mode 100644 tests/aiet/test_cli_common.py create mode 100644 tests/aiet/test_cli_system.py create mode 100644 tests/aiet/test_cli_tool.py create mode 100644 tests/aiet/test_main.py create mode 100644 tests/aiet/test_resources/application_config.json create mode 100644 tests/aiet/test_resources/application_config.json.license create mode 100644 tests/aiet/test_resources/applications/application1/aiet-config.json create mode 100644 tests/aiet/test_resources/applications/application1/aiet-config.json.license create mode 100644 tests/aiet/test_resources/applications/application2/aiet-config.json create mode 100644 tests/aiet/test_resources/applications/application2/aiet-config.json.license create mode 100644 tests/aiet/test_resources/applications/application3/readme.txt create mode 100644 tests/aiet/test_resources/applications/application4/aiet-config.json create mode 100644 tests/aiet/test_resources/applications/application4/aiet-config.json.license create mode 100644 tests/aiet/test_resources/applications/application4/hello_app.txt create mode 100644 tests/aiet/test_resources/applications/application5/aiet-config.json create mode 100644 tests/aiet/test_resources/applications/application5/aiet-config.json.license create mode 100644 tests/aiet/test_resources/applications/readme.txt create mode 100644 tests/aiet/test_resources/hello_world.json create mode 100644 tests/aiet/test_resources/hello_world.json.license create mode 100755 tests/aiet/test_resources/scripts/test_backend_run create mode 100644 tests/aiet/test_resources/scripts/test_backend_run_script.sh create mode 100644 tests/aiet/test_resources/systems/system1/aiet-config.json create mode 100644 tests/aiet/test_resources/systems/system1/aiet-config.json.license create mode 100644 tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt create mode 100644 tests/aiet/test_resources/systems/system2/aiet-config.json create mode 100644 tests/aiet/test_resources/systems/system2/aiet-config.json.license create mode 100644 tests/aiet/test_resources/systems/system3/readme.txt create mode 100644 tests/aiet/test_resources/systems/system4/aiet-config.json create mode 100644 tests/aiet/test_resources/systems/system4/aiet-config.json.license create mode 100644 tests/aiet/test_resources/tools/tool1/aiet-config.json create mode 100644 tests/aiet/test_resources/tools/tool1/aiet-config.json.license create mode 100644 tests/aiet/test_resources/tools/tool2/aiet-config.json create mode 100644 tests/aiet/test_resources/tools/tool2/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json create mode 100644 tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json create mode 100644 tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json create mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json create mode 100644 tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license create mode 100644 tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json create mode 100644 tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license create mode 100644 tests/aiet/test_run_vela_script.py create mode 100644 tests/aiet/test_utils_fs.py create mode 100644 tests/aiet/test_utils_helpers.py create mode 100644 tests/aiet/test_utils_proc.py create mode 100644 tests/conftest.py create mode 100644 tests/mlia/__init__.py create mode 100644 tests/mlia/conftest.py create mode 100644 tests/mlia/test_api.py create mode 100644 tests/mlia/test_cli_commands.py create mode 100644 tests/mlia/test_cli_config.py create mode 100644 tests/mlia/test_cli_helpers.py create mode 100644 tests/mlia/test_cli_logging.py create mode 100644 tests/mlia/test_cli_main.py create mode 100644 tests/mlia/test_cli_options.py create mode 100644 tests/mlia/test_core_advice_generation.py create mode 100644 tests/mlia/test_core_advisor.py create mode 100644 tests/mlia/test_core_context.py create mode 100644 tests/mlia/test_core_data_analysis.py create mode 100644 tests/mlia/test_core_events.py create mode 100644 tests/mlia/test_core_helpers.py create mode 100644 tests/mlia/test_core_mixins.py create mode 100644 tests/mlia/test_core_performance.py create mode 100644 tests/mlia/test_core_reporting.py create mode 100644 tests/mlia/test_core_workflow.py create mode 100644 tests/mlia/test_devices_ethosu_advice_generation.py create mode 100644 tests/mlia/test_devices_ethosu_advisor.py create mode 100644 tests/mlia/test_devices_ethosu_config.py create mode 100644 tests/mlia/test_devices_ethosu_data_analysis.py create mode 100644 tests/mlia/test_devices_ethosu_data_collection.py create mode 100644 tests/mlia/test_devices_ethosu_performance.py create mode 100644 tests/mlia/test_devices_ethosu_reporters.py create mode 100644 tests/mlia/test_nn_tensorflow_config.py create mode 100644 tests/mlia/test_nn_tensorflow_optimizations_clustering.py create mode 100644 tests/mlia/test_nn_tensorflow_optimizations_pruning.py create mode 100644 tests/mlia/test_nn_tensorflow_optimizations_select.py create mode 100644 tests/mlia/test_nn_tensorflow_tflite_metrics.py create mode 100644 tests/mlia/test_nn_tensorflow_utils.py create mode 100644 tests/mlia/test_resources/vela/sample_vela.ini create mode 100644 tests/mlia/test_tools_aiet_wrapper.py create mode 100644 tests/mlia/test_tools_metadata_common.py create mode 100644 tests/mlia/test_tools_metadata_corstone.py create mode 100644 tests/mlia/test_tools_vela_wrapper.py create mode 100644 tests/mlia/test_utils_console.py create mode 100644 tests/mlia/test_utils_download.py create mode 100644 tests/mlia/test_utils_filesystem.py create mode 100644 tests/mlia/test_utils_logging.py create mode 100644 tests/mlia/test_utils_misc.py create mode 100644 tests/mlia/test_utils_proc.py create mode 100644 tests/mlia/test_utils_types.py create mode 100644 tests/mlia/utils/__init__.py create mode 100644 tests/mlia/utils/common.py create mode 100644 tests/mlia/utils/logging.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e1557d9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +*.pyc +*~ +\.coverage +\.eggs +build +dist +src/*.egg-info +.vscode +venv +e2e_config +mlia_output +report +.ipynb_checkpoints diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 0000000..f433b1a --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS diff --git a/LICENSES/BSD-3-Clause.txt b/LICENSES/BSD-3-Clause.txt new file mode 100644 index 0000000..e7674f7 --- /dev/null +++ b/LICENSES/BSD-3-Clause.txt @@ -0,0 +1,26 @@ +Copyright + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/LICENSES/CC-PDDC.txt b/LICENSES/CC-PDDC.txt new file mode 100644 index 0000000..b4ecb41 --- /dev/null +++ b/LICENSES/CC-PDDC.txt @@ -0,0 +1,28 @@ +Creative Commons Public Domain Dedication and Certification + +The person or persons who have associated work with this document (the +"Dedicator" or "Certifier") hereby either (a) certifies that, to the best of +his knowledge, the work of authorship identified is in the public domain of the +country from which the work is published, or (b) hereby dedicates whatever +copyright the dedicators holds in the work of authorship identified below (the +"Work") to the public domain. A certifier, moreover, dedicates any copyright +interest he may have in the associated work, and for these purposes, is +described as a "dedicator" below. + +A certifier has taken reasonable steps to verify the copyright status of this +work. Certifier recognizes that his good faith efforts may not shield him from +liability if in fact the work certified is not in the public domain. + +Dedicator makes this dedication for the benefit of the public at large and to +the detriment of the Dedicator's heirs and successors. Dedicator intends this +dedication to be an overt act of relinquishment in perpetuity of all present +and future rights under copyright law, whether vested or contingent, in the +Work. Dedicator understands that such relinquishment of all rights includes +the relinquishment of all rights to enforce (by lawsuit or otherwise) those +copyrights in the Work. + +Dedicator recognizes that, once placed in the public domain, the Work may be +freely reproduced, distributed, transmitted, used, modified, built upon, or +otherwise exploited by anyone for any purpose, commercial or non-commercial, +and in any way, including by methods that have not yet been invented or +conceived. diff --git a/LICENSES/MIT.txt b/LICENSES/MIT.txt new file mode 100644 index 0000000..8aa2645 --- /dev/null +++ b/LICENSES/MIT.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) [year] [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..0f012d2 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +include README.md +graft src/mlia/resources diff --git a/README.md b/README.md new file mode 100644 index 0000000..dd1fd05 --- /dev/null +++ b/README.md @@ -0,0 +1,312 @@ + +# ML Inference Advisor + +## Introduction + +This tool is used to help AI developers design and optimize neural network +models for efficient inference on Arm® targets by enabling performance analysis +and providing actionable advice early in the model development cycle. The final +advice can cover the operator list, performance analysis and suggestions for +model inference run on certain hardware before/after applying model optimization +(e.g. pruning, clustering, etc.). + +## Prerequisites and dependencies + +It is recommended to use virtual environments for MLIA installation, and a +typical setup for MLIA requires: + +* Ubuntu® 20.04.03 LTS (other OSs may work, the ML Inference Advisor has been + tested on this one specifically) +* Python® >= 3.8 +* Ethos™-U Vela dependencies (Linux® only) + * For more details, please refer to the + [prerequisites of Vela](https://pypi.org/project/ethos-u-vela/) + +## Backend installation + +### Generic case using Corstone™-300 as an example + +The ML Inference Advisor is designed to support multiple performance +estimators (backends) that could generate performance analysis for individual +types of hardware. In this guide, we use the backend for +Ethos™-U (Corstone™-300) as an example. + +The install command can automatically download the necessary components and +dependencies, install them and configure them properly. + +The usage is: + +```bash +mlia backend install --help +``` + +and the result looks like: + +positional arguments: + +* name: Name of the backend to install + +optional arguments: + +* -h/--help: Show this help message and exit +* --path PATH: Path to the installed backend +* --download: Download and install a backend +* --noninteractive: Non interactive mode with automatic confirmation of every action + +Some examples of the installation process are: + +```bash +# reply 'y' or 'n' when prompted to download and install a Corstone-300 +mlia backend install --download +# for downloading and installing a specific backend +mlia backend install Corstone-300 --download +# for installing backend from the path of your downloaded backend +mlia backend install --path your_local_path_for_the_installed_backend +``` + +Please note: Corstone™-300 used in the example above is available only +on the Linux® platform. + +After a successful installation of the backend(s), start using mlia in your +virtual environment. Please note: backends cannot be removed once installed. +Consider creating new environment and reinstall backends when needed. + +### Using Corstone™-310 + +For instructions on installing Corstone™-310, please refer to + + +## Usage + +After the initial setup, you can start the program by opening your terminal and +typing the following command: + +```bash +mlia [command] [arguments] +``` + +where [command] is to be substituted by one of the supported options, discussed in +the next section. + +To get a list of all available options, use: + +```bash +mlia --help +``` + +To get help on a specific command, use: + +```bash +mlia [command] --help +``` + +Choices of commands: you can use ["operators"](#operators-ops) command for the +model's operator list, run the specified optimizations using +["model optimization"](#model-optimization-opt) command, or measure the +performance of inference on hardware using ["performance"](#performance-perf) +command. In the end, you can use ["all tests"](#all-tests-all) command to +have a full report. + +Most commands accept the name of the target profile name as input parameter. +There are a number of predefined profiles with following attributes: + +``` ++--------------------------------------------------------------------+ +| Profile name | MAC | System config | Memory mode | ++===================================================================== +| ethos-u55-256 | 256 | Ethos_U55_High_End_Embedded | Shared_Sram | ++--------------------------------------------------------------------- +| ethos-u55-128 | 128 | Ethos_U55_High_End_Embedded | Shared_Sram | ++--------------------------------------------------------------------- +| ethos-u65-512 | 512 | Ethos_U65_High_End | Dedicated_Sram | ++--------------------------------------------------------------------+ +``` + +### **Operators** (ops) + +#### *Description* + +Prints the model's operator list. + +#### *Arguments* + +##### Optional arguments + +* -h/--help: Show the general help document and exit. +* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current working + directory and exit. + +##### Target profile options + +* --target-profile: Target profile that will set the target options such as + target, mac value, memory mode, etc ... + * default: ethos-u55-256 + * options: + * ethos-u55-256 + * ethos-u55-128 + * ethos-u65-512 + +##### TFLite model options + +* model: Input model in TFLite format [required]. + +##### Output options + +* --output: Name of the file where the report will be saved. + The report is also displayed the standard output, as plain text. + Valid file extensions (formats) are {.txt,.json,.csv}, + anything else will be formatted as plain text. + +### **Performance** (perf) + +#### *Description* + +Prints the model's performance statistics. + +#### *Arguments* + +##### optional arguments + +* -h/--help: Show the general help document and exit. +* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current + working directory and exit. + +##### Target profile options + +* --target-profile: Target profile that will set the target options such as + target, mac value, memory mode, etc ... + * default: ethos-u55-256 + * options: + * ethos-u55-256 + * ethos-u55-128 + * ethos-u65-512 + +##### TFLite model options + +* model: Input model in TFLite format [required]. + +##### Output options + +* --output: Name of the file where the report will be saved. + The report is also displayed the standard output, as plain text. + Valid file extensions (formats) are {.txt,.json,.csv}, + anything else will be formatted as plain text. + +##### Debug options + +* --verbose: Produce verbose output (for debugging purposes). + +### **Model optimization** (opt) + +#### *Description* + +Shows the performance improvements after applying optimizations to the model. + +#### *Arguments* + +##### optional arguments + +* -h/--help: Show the general help document and exit. +* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current + working directory and exit. + +##### Target profile options + +* --target-profile: Target profile that will set the target options such as + target, mac value, memory mode, etc ... + * default: ethos-u55-256 + * options: + * ethos-u55-256 + * ethos-u55-128 + * ethos-u65-512 + +##### Keras™ model options + +* model: Input model in Keras™ (.h5 or SavedModel) format [required]. + +##### optimization options + +* --optimization-type: Type of optimization to apply to the model [required]. + * options: + * pruning + * clustering +* --optimization-target: Target for optimization (for pruning this is sparsity + between (0,1), for clustering this is the number of clusters + (positive integer)) [required]. +* --layers-to-optimize: Name of the layers to optimize (separated by space). + Example: conv1 conv2 conv3 + * default: every layer + +##### Debug options + +* --verbose: Produce verbose output (for debugging purposes). + +### **All tests** (all) + +#### *Description* + +Generates a full report on the input model's operator list, +runs the specified optimizations and lists the performance improvements. + +#### *Arguments* + +##### Optional arguments + +* -h/--help: show this help message and exit + +##### Target profile options + +* --target-profile: Target profile that will set the target options such as + target, mac value, memory mode, etc ... + * default: ethos-u55-256 + * options: + * ethos-u55-256 + * ethos-u55-128 + * ethos-u65-512 + +##### Keras™ model options + +* model: Input model in Keras™ (.h5 or SavedModel) format [required]. + +##### Optimization options + +* --optimization-type: List of the optimization types separated by comma + * default: pruning, clustering +* --optimization-target: List of the optimization targets separated by comma, + (for pruning this is sparsity between (0,1), for clustering this is the + number of clusters (positive integer)) + * default: 0.5, 32 + +##### Output options + +* --output: Name of the file where the report will be saved. + The report is also displayed the standard output, as plain text. + Valid file extensions (formats) are {.txt,.json,.csv}, + anything else will be formatted as plain text. + +##### Debug options + +* --verbose: Produce verbose output (for debugging purposes). + +## Resources + +Additional useful information: + +* [Corstone™-300](https://developer.arm.com/Processors/Corstone-300) + +## License + +ML Inference Advisor is licensed under [Apache License 2.0](LICENSE.txt). + +## Trademarks and Copyrights + +Arm®, Ethos™-U, Cortex®-M, Corstone™ are registered trademarks or trademarks +of Arm® Limited (or its subsidiaries) in the U.S. and/or elsewhere. +TensorFlow™ is a trademark of Google® LLC. +Keras™ is a trademark by François Chollet. +Linux® is the registered trademark of Linus Torvalds in the U.S. and elsewhere. +Python® is a registered trademark of the PSF. +Ubuntu® is a registered trademark of Canonical. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..05363d8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright (c) 2012-2022 Jukka Lehtosalo and contributors +# SPDX-FileCopyrightText: Copyright (c) 2015-2022 Dropbox, Inc. +# SPDX-License-Identifier: Apache-2.0 AND MIT + + +[build-system] +requires = [ + "setuptools>=42", + "wheel", + "setuptools_scm[toml]>=6.2" +] +build-backend = "setuptools.build_meta" + +# Enable setuptools_scm +[tool.setuptools_scm] + +[tool.pytest.ini_options] +testpaths = "tests" +markers = [ + "e2e", # e2e tests + "install", # installation tests + "command", # command tests + "model_gen" # model generation tests +] +junit_logging = "all" + +[tool.pylint.messages_control] +min-similarity-lines = 10 +min-public-methods = 1 +max-line-length = 88 +max-args = 8 +max-attributes=10 + +# Provide basic compatibility with black +disable = [ + "wrong-import-order", + "consider-using-f-string" # C0209 +] + +enable = [ + "dangerous-default-value", # W0102 + # black will reflow code lines, but won't touch comments, error on those + "line-too-long" # C0301 +] + +[tool.mypy] +# Suppresses error messages about imports that cannot be resolved +ignore_missing_imports = true +# Shows a warning when encountering any code inferred to be unreachable or redundant after performing type analysis +warn_unreachable = true +# Shows errors for missing return statements on some execution paths +warn_no_return = true +# Shows a warning when returning a value with type Any from a function declared with a non- Any return type +warn_return_any = true +# Warns about unneeded # type: ignore comments +warn_unused_ignores = true +# Warns about casting an expression to its inferred type +warn_redundant_casts = true +# Disallows calling functions without type annotations from functions with type annotations +disallow_untyped_calls = true +# Disallows defining functions without type annotations or with incomplete type annotations +disallow_untyped_defs = true +# Disallows defining functions with incomplete type annotations +disallow_incomplete_defs = true +# Reports an error whenever a function with type annotations is decorated with a decorator without annotations +disallow_untyped_decorators = true +# Type-checks the interior of functions without type annotations +check_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "pkg_resources", + "paramiko", + "requests", + "filelock" +] +ignore_missing_imports = true diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..5a9202a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi +# SPDX-License-Identifier: Apache-2.0 AND MIT + +[metadata] +name = mlia +description = ML Inference Advisor +long_description = file: README.md +url = https://git.mlplatform.org/ml/mlia.git +author = Arm Ltd +author_email = mlia@arm.com +license = Apache License 2.0 +license_files = LICENSES/*.txt +classifiers = + Development Status :: 4 - Beta + License :: OSI Approved :: Apache Software License + Intended Audience :: Developers + Operating System :: POSIX :: Linux + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Topic :: Scientific/Engineering :: Artificial Intelligence +keywords = ml, arm, ethos-u, tflite + +[options] +include_package_data = True +python_requires = >=3.8 +package_dir = + = src +packages = find: +install_requires = + tensorflow~=2.7.1 + tensorflow-model-optimization~=0.7.2 + ethos-u-vela~=3.3.0 + requests + rich + click + sh + paramiko + filelock + psutil + cloup>=0.12.0 + +[options.packages.find] +where = src + +[options.entry_points] +console_scripts = + mlia=mlia.cli.main:main + aiet=aiet.main:main + run_vela=aiet.resources.tools.vela.run_vela:main + +[options.extras_require] +dev = + pytest==7.1.1 + pytest-cov==3.0.0 + mypy==0.942 + pylint==2.13.7 + +[flake8] +# ignored errors +# E501 line too long +# W503 line break before binary operator +ignore = E501, W503 +max-complexity = 18 +select = B,C,E,F,W,T4 + +[blocklint] +# Do not allow any non-inclusive language +max_issue_threshold=1 +# Blocklist: Words to lint in any context, with possibly special characters +# between, case insensitive +blocklist=master,slave,blacklist,whitelist +# Word list: Words to lint as whole words, with possibly special characters +# between, case insensitive +wordlist=he,she,him,her,his,hers +# Exact list: Words to lint as whole words exactly as entered +# exactlist= +# Files that should not be checked by blocklint. +skip_files=LICENSES/CC-PDDC.txt diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4707ea5 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to setup the python package.""" +from setuptools import setup + + +setup() diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py new file mode 100644 index 0000000..49304b1 --- /dev/null +++ b/src/aiet/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Init of aiet.""" +import pkg_resources + + +__version__ = pkg_resources.get_distribution("mlia").version diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py new file mode 100644 index 0000000..3d60372 --- /dev/null +++ b/src/aiet/backend/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend module.""" diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py new file mode 100644 index 0000000..f6ef815 --- /dev/null +++ b/src/aiet/backend/application.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Application backend module.""" +import re +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional + +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import DataPaths +from aiet.backend.common import get_backend_configs +from aiet.backend.common import get_backend_directories +from aiet.backend.common import load_application_or_tool_configs +from aiet.backend.common import load_config +from aiet.backend.common import remove_backend +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import ExtendedApplicationConfig +from aiet.backend.source import create_destination_and_install +from aiet.backend.source import get_source +from aiet.utils.fs import get_resources + + +def get_available_application_directory_names() -> List[str]: + """Return a list of directory names for all available applications.""" + return [entry.name for entry in get_backend_directories("applications")] + + +def get_available_applications() -> List["Application"]: + """Return a list with all available applications.""" + available_applications = [] + for config_json in get_backend_configs("applications"): + config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json)) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + applications = load_applications(config_entry) + available_applications += applications + + return sorted(available_applications, key=lambda application: application.name) + + +def get_application( + application_name: str, system_name: Optional[str] = None +) -> List["Application"]: + """Return a list of application instances with provided name.""" + return [ + application + for application in get_available_applications() + if application.name == application_name + and (not system_name or application.can_run_on(system_name)) + ] + + +def install_application(source_path: Path) -> None: + """Install application.""" + try: + source = get_source(source_path) + config = cast(List[ExtendedApplicationConfig], source.config()) + applications_to_install = [ + s for entry in config for s in load_applications(entry) + ] + except Exception as error: + raise ConfigurationException("Unable to read application definition") from error + + if not applications_to_install: + raise ConfigurationException("No application definition found") + + available_applications = get_available_applications() + already_installed = [ + s for s in applications_to_install if s in available_applications + ] + if already_installed: + names = {application.name for application in already_installed} + raise ConfigurationException( + "Applications [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_resources("applications")) + + +def remove_application(directory_name: str) -> None: + """Remove application directory.""" + remove_backend(directory_name, "applications") + + +def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: + """Extract a list of unique application names of all application available.""" + return list( + set( + application.name + for application in get_available_applications() + if not system_name or application.can_run_on(system_name) + ) + ) + + +class Application(Backend): + """Class for representing a single application component.""" + + def __init__(self, config: ApplicationConfig) -> None: + """Construct a Application instance from a dict.""" + super().__init__(config) + + self.supported_systems = config.get("supported_systems", []) + self.deploy_data = config.get("deploy_data", []) + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Application): + return False + + return ( + super().__eq__(other) + and self.name == other.name + and set(self.supported_systems) == set(other.supported_systems) + ) + + def can_run_on(self, system_name: str) -> bool: + """Check if the application can run on the system passed as argument.""" + return system_name in self.supported_systems + + def get_deploy_data(self) -> List[DataPaths]: + """Validate and return data specified in the config file.""" + if self.config_location is None: + raise ConfigurationException( + "Unable to get application {} config location".format(self.name) + ) + + deploy_data = [] + for item in self.deploy_data: + src, dst = item + src_full_path = self.config_location / src + assert src_full_path.exists(), "{} does not exists".format(src_full_path) + deploy_data.append(DataPaths(src_full_path, dst)) + return deploy_data + + def get_details(self) -> Dict[str, Any]: + """Return dictionary with information about the Application instance.""" + output = { + "type": "application", + "name": self.name, + "description": self.description, + "supported_systems": self.supported_systems, + "commands": self._get_command_details(), + } + + return output + + def remove_unused_params(self) -> None: + """Remove unused params in commands. + + After merging default and system related configuration application + could have parameters that are not being used in commands. They + should be removed. + """ + for command in self.commands.values(): + indexes_or_aliases = [ + m + for cmd_str in command.command_strings + for m in re.findall(r"{user_params:(?P\w+)}", cmd_str) + ] + + only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) + if only_aliases: + used_params = [ + param + for param in command.params + if param.alias in indexes_or_aliases + ] + command.params = used_params + + +def load_applications(config: ExtendedApplicationConfig) -> List[Application]: + """Load application. + + Application configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + Application instance with appropriate configuration. + """ + configs = load_application_or_tool_configs(config, ApplicationConfig) + applications = [Application(cfg) for cfg in configs] + for application in applications: + application.remove_unused_params() + return applications diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py new file mode 100644 index 0000000..b887ee7 --- /dev/null +++ b/src/aiet/backend/common.py @@ -0,0 +1,532 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain all common functions for the backends.""" +import json +import logging +import re +from abc import ABC +from collections import Counter +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Final +from typing import IO +from typing import Iterable +from typing import List +from typing import Match +from typing import NamedTuple +from typing import Optional +from typing import Pattern +from typing import Tuple +from typing import Type +from typing import Union + +from aiet.backend.config import BackendConfig +from aiet.backend.config import BaseBackendConfig +from aiet.backend.config import NamedExecutionConfig +from aiet.backend.config import UserParamConfig +from aiet.backend.config import UserParamsConfig +from aiet.utils.fs import get_resources +from aiet.utils.fs import remove_resource +from aiet.utils.fs import ResourceType + + +AIET_CONFIG_FILE: Final[str] = "aiet-config.json" + + +class ConfigurationException(Exception): + """Configuration exception.""" + + +def get_backend_config(dir_path: Path) -> Path: + """Get path to backendir configuration file.""" + return dir_path / AIET_CONFIG_FILE + + +def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend configs for provided resource_type.""" + return ( + get_backend_config(entry) for entry in get_backend_directories(resource_type) + ) + + +def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend directories for provided resource_type.""" + return ( + entry + for entry in get_resources(resource_type).iterdir() + if is_backend_directory(entry) + ) + + +def is_backend_directory(dir_path: Path) -> bool: + """Check if path is backend's configuration directory.""" + return dir_path.is_dir() and get_backend_config(dir_path).is_file() + + +def remove_backend(directory_name: str, resource_type: ResourceType) -> None: + """Remove backend with provided type and directory_name.""" + if not directory_name: + raise Exception("No directory name provided") + + remove_resource(directory_name, resource_type) + + +def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: + """Return a loaded json file.""" + if config is None: + raise Exception("Unable to read config") + + if isinstance(config, Path): + with config.open() as json_file: + return cast(BackendConfig, json.load(json_file)) + + return cast(BackendConfig, json.load(config)) + + +def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: + """Split the parameter string in name and optional value. + + It manages the following cases: + --param=1 -> --param, 1 + --param 1 -> --param, 1 + --flag -> --flag, None + """ + data = re.split(" |=", parameter) + if len(data) == 1: + param_name = data[0] + param_value = None + else: + param_name = " ".join(data[0:-1]) + param_value = data[-1] + return param_name, param_value + + +class DataPaths(NamedTuple): + """DataPaths class.""" + + src: Path + dst: str + + +class Backend(ABC): + """Backend class.""" + + # pylint: disable=too-many-instance-attributes + + def __init__(self, config: BaseBackendConfig): + """Initialize backend.""" + name = config.get("name") + if not name: + raise ConfigurationException("Name is empty") + + self.name = name + self.description = config.get("description", "") + self.config_location = config.get("config_location") + self.variables = config.get("variables", {}) + self.build_dir = config.get("build_dir") + self.lock = config.get("lock", False) + if self.build_dir: + self.build_dir = self._substitute_variables(self.build_dir) + self.annotations = config.get("annotations", {}) + + self._parse_commands_and_params(config) + + def validate_parameter(self, command_name: str, parameter: str) -> bool: + """Validate the parameter string against the application configuration. + + We take the parameter string, extract the parameter name/value and + check them against the current configuration. + """ + param_name, param_value = parse_raw_parameter(parameter) + valid_param_name = valid_param_value = False + + command = self.commands.get(command_name) + if not command: + raise AttributeError("Unknown command: '{}'".format(command_name)) + + # Iterate over all available parameters until we have a match. + for param in command.params: + if self._same_parameter(param_name, param): + valid_param_name = True + # This is a non-empty list + if param.values: + # We check if the value is allowed in the configuration + valid_param_value = param_value in param.values + else: + # In this case we don't validate the value and accept + # whatever we have set. + valid_param_value = True + break + + return valid_param_name and valid_param_value + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Backend): + return False + + return ( + self.name == other.name + and self.description == other.description + and self.commands == other.commands + ) + + def __repr__(self) -> str: + """Represent the Backend instance by its name.""" + return self.name + + def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: + """Parse commands and user parameters.""" + self.commands: Dict[str, Command] = {} + + commands = config.get("commands") + if commands: + params = config.get("user_params") + + for command_name in commands.keys(): + command_params = self._parse_params(params, command_name) + command_strings = [ + self._substitute_variables(cmd) + for cmd in commands.get(command_name, []) + ] + self.commands[command_name] = Command(command_strings, command_params) + + def _substitute_variables(self, str_val: str) -> str: + """Substitute variables in string. + + Variables is being substituted at backend's creation stage because + they could contain references to other params which will be + resolved later. + """ + if not str_val: + return str_val + + var_pattern: Final[Pattern] = re.compile(r"{variables:(?P\w+)}") + + def var_value(match: Match) -> str: + var_name = match["var_name"] + if var_name not in self.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return self.variables[var_name] + + return var_pattern.sub(var_value, str_val) # type: ignore + + @classmethod + def _parse_params( + cls, params: Optional[UserParamsConfig], command: str + ) -> List["Param"]: + if not params: + return [] + + return [cls._parse_param(p) for p in params.get(command, [])] + + @classmethod + def _parse_param(cls, param: UserParamConfig) -> "Param": + """Parse a single parameter.""" + name = param.get("name") + if name is not None and not name: + raise ConfigurationException("Parameter has an empty 'name' attribute.") + values = param.get("values", None) + default_value = param.get("default_value", None) + description = param.get("description", "") + alias = param.get("alias") + + return Param( + name=name, + description=description, + values=values, + default_value=default_value, + alias=alias, + ) + + def _get_command_details(self) -> Dict: + command_details = { + command_name: command.get_details() + for command_name, command in self.commands.items() + } + return command_details + + def _get_user_param_value( + self, user_params: List[str], param: "Param" + ) -> Optional[str]: + """Get the user-specified value of a parameter.""" + for user_param in user_params: + user_param_name, user_param_value = parse_raw_parameter(user_param) + if user_param_name == param.name: + warn_message = ( + "The direct use of parameter name is deprecated" + " and might be removed in the future.\n" + f"Please use alias '{param.alias}' instead of " + "'{user_param_name}' to provide the parameter." + ) + logging.warning(warn_message) + + if self._same_parameter(user_param_name, param): + return user_param_value + + return None + + @staticmethod + def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: + """Compare user parameter name with param name or alias.""" + # Strip the "=" sign in the param_name. This is needed just for + # comparison with the parameters passed by the user. + # The equal sign needs to be honoured when re-building the + # parameter back. + param_name = None if not param.name else param.name.rstrip("=") + return user_param_name_or_alias in [param_name, param.alias] + + def resolved_parameters( + self, command_name: str, user_params: List[str] + ) -> List[Tuple[Optional[str], "Param"]]: + """Return list of parameters with values.""" + result: List[Tuple[Optional[str], "Param"]] = [] + command = self.commands.get(command_name) + if not command: + return result + + for param in command.params: + value = self._get_user_param_value(user_params, param) + if not value: + value = param.default_value + result.append((value, param)) + + return result + + def build_command( + self, + command_name: str, + user_params: List[str], + param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], + ) -> List[str]: + """ + Return a list of executable command strings. + + Given a command and associated parameters, returns a list of executable command + strings. + """ + command = self.commands.get(command_name) + if not command: + raise ConfigurationException( + "Command '{}' could not be found.".format(command_name) + ) + + commands_to_run = [] + + params_values = self.resolved_parameters(command_name, user_params) + for cmd_str in command.command_strings: + cmd_str = resolve_all_parameters( + cmd_str, param_resolver, command_name, params_values + ) + commands_to_run.append(cmd_str) + + return commands_to_run + + +class Param: + """Class for representing a generic application parameter.""" + + def __init__( # pylint: disable=too-many-arguments + self, + name: Optional[str], + description: str, + values: Optional[List[str]] = None, + default_value: Optional[str] = None, + alias: Optional[str] = None, + ) -> None: + """Construct a Param instance.""" + if not name and not alias: + raise ConfigurationException( + "Either name, alias or both must be set to identify a parameter." + ) + self.name = name + self.values = values + self.description = description + self.default_value = default_value + self.alias = alias + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Param.""" + return {key: value for key, value in self.__dict__.items() if value} + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Param): + return False + + return ( + self.name == other.name + and self.values == other.values + and self.default_value == other.default_value + and self.description == other.description + ) + + +class Command: + """Class for representing a command.""" + + def __init__( + self, command_strings: List[str], params: Optional[List[Param]] = None + ) -> None: + """Construct a Command instance.""" + self.command_strings = command_strings + + if params: + self.params = params + else: + self.params = [] + + self._validate() + + def _validate(self) -> None: + """Validate command.""" + if not self.params: + return + + aliases = [param.alias for param in self.params if param.alias is not None] + repeated_aliases = [ + alias for alias, count in Counter(aliases).items() if count > 1 + ] + + if repeated_aliases: + raise ConfigurationException( + "Non unique aliases {}".format(", ".join(repeated_aliases)) + ) + + both_name_and_alias = [ + param.name + for param in self.params + if param.name in aliases and param.name != param.alias + ] + if both_name_and_alias: + raise ConfigurationException( + "Aliases {} could not be used as parameter name".format( + ", ".join(both_name_and_alias) + ) + ) + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Command.""" + output = { + "command_strings": self.command_strings, + "user_params": [param.get_details() for param in self.params], + } + return output + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Command): + return False + + return ( + self.command_strings == other.command_strings + and self.params == other.params + ) + + +def resolve_all_parameters( + str_val: str, + param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], + command_name: Optional[str] = None, + params_values: Optional[List[Tuple[Optional[str], Param]]] = None, +) -> str: + """Resolve all parameters in the string.""" + if not str_val: + return str_val + + param_pattern: Final[Pattern] = re.compile(r"{(?P[\w.:]+)}") + while param_pattern.findall(str_val): + str_val = param_pattern.sub( + lambda m: param_resolver( + m["param_name"], command_name or "", params_values or [] + ), + str_val, + ) + return str_val + + +def load_application_or_tool_configs( + config: Any, + config_type: Type[Any], + is_system_required: bool = True, +) -> Any: + """Get one config for each system supported by the application/tool. + + The configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + config with appropriate configuration. + """ + merged_configs = [] + supported_systems: Optional[List[NamedExecutionConfig]] = config.get( + "supported_systems" + ) + if not supported_systems: + if is_system_required: + raise ConfigurationException("No supported systems definition provided") + # Create an empty system to be used in the parsing below + supported_systems = [cast(NamedExecutionConfig, {})] + + default_user_params = config.get("user_params", {}) + + def merge_config(system: NamedExecutionConfig) -> Any: + system_name = system.get("name") + if not system_name and is_system_required: + raise ConfigurationException( + "Unable to read supported system definition, name is missed" + ) + + merged_config = config_type(**config) + merged_config["supported_systems"] = [system_name] if system_name else [] + # merge default configuration and specific to the system + merged_config["commands"] = { + **config.get("commands", {}), + **system.get("commands", {}), + } + + params = {} + tool_user_params = system.get("user_params", {}) + command_names = tool_user_params.keys() | default_user_params.keys() + for command_name in command_names: + if command_name not in merged_config["commands"]: + continue + + params_default = default_user_params.get(command_name, []) + params_tool = tool_user_params.get(command_name, []) + if not params_default or not params_tool: + params[command_name] = params_tool or params_default + if params_default and params_tool: + if any(not p.get("alias") for p in params_default): + raise ConfigurationException( + "Default parameters for command {} should have aliases".format( + command_name + ) + ) + if any(not p.get("alias") for p in params_tool): + raise ConfigurationException( + "{} parameters for command {} should have aliases".format( + system_name, command_name + ) + ) + + merged_by_alias = { + **{p.get("alias"): p for p in params_default}, + **{p.get("alias"): p for p in params_tool}, + } + params[command_name] = list(merged_by_alias.values()) + + merged_config["user_params"] = params + merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) + merged_config["lock"] = system.get("lock", config.get("lock", False)) + merged_config["variables"] = { + **config.get("variables", {}), + **system.get("variables", {}), + } + return merged_config + + merged_configs = [merge_config(system) for system in supported_systems] + + return merged_configs diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py new file mode 100644 index 0000000..dd42012 --- /dev/null +++ b/src/aiet/backend/config.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain definition of backend configuration.""" +from pathlib import Path +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import TypedDict +from typing import Union + + +class UserParamConfig(TypedDict, total=False): + """User parameter configuration.""" + + name: Optional[str] + default_value: str + values: List[str] + description: str + alias: str + + +UserParamsConfig = Dict[str, List[UserParamConfig]] + + +class ExecutionConfig(TypedDict, total=False): + """Execution configuration.""" + + commands: Dict[str, List[str]] + user_params: UserParamsConfig + build_dir: str + variables: Dict[str, str] + lock: bool + + +class NamedExecutionConfig(ExecutionConfig): + """Execution configuration with name.""" + + name: str + + +class BaseBackendConfig(ExecutionConfig, total=False): + """Base backend configuration.""" + + name: str + description: str + config_location: Path + annotations: Dict[str, Union[str, List[str]]] + + +class ApplicationConfig(BaseBackendConfig, total=False): + """Application configuration.""" + + supported_systems: List[str] + deploy_data: List[Tuple[str, str]] + + +class ExtendedApplicationConfig(BaseBackendConfig, total=False): + """Extended application configuration.""" + + supported_systems: List[NamedExecutionConfig] + deploy_data: List[Tuple[str, str]] + + +class ProtocolConfig(TypedDict, total=False): + """Protocol config.""" + + protocol: Literal["local", "ssh"] + + +class SSHConfig(ProtocolConfig, total=False): + """SSH configuration.""" + + username: str + password: str + hostname: str + port: str + + +class LocalProtocolConfig(ProtocolConfig, total=False): + """Local protocol config.""" + + +class SystemConfig(BaseBackendConfig, total=False): + """System configuration.""" + + data_transfer: Union[SSHConfig, LocalProtocolConfig] + reporting: Dict[str, Dict] + + +class ToolConfig(BaseBackendConfig, total=False): + """Tool configuration.""" + + supported_systems: List[str] + + +class ExtendedToolConfig(BaseBackendConfig, total=False): + """Extended tool configuration.""" + + supported_systems: List[NamedExecutionConfig] + + +BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig] +BackendConfig = Union[ + List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig] +] diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py new file mode 100644 index 0000000..2650902 --- /dev/null +++ b/src/aiet/backend/controller.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Controller backend module.""" +import time +from pathlib import Path +from typing import List +from typing import Optional +from typing import Tuple + +import psutil +import sh + +from aiet.backend.common import ConfigurationException +from aiet.utils.fs import read_file_as_string +from aiet.utils.proc import execute_command +from aiet.utils.proc import get_stdout_stderr_paths +from aiet.utils.proc import read_process_info +from aiet.utils.proc import save_process_info +from aiet.utils.proc import terminate_command +from aiet.utils.proc import terminate_external_process + + +class SystemController: + """System controller class.""" + + def __init__(self) -> None: + """Create new instance of service controller.""" + self.cmd: Optional[sh.RunningCommand] = None + self.out_path: Optional[Path] = None + self.err_path: Optional[Path] = None + + def before_start(self) -> None: + """Run actions before system start.""" + + def after_start(self) -> None: + """Run actions after system start.""" + + def start(self, commands: List[str], cwd: Path) -> None: + """Start system.""" + if not isinstance(cwd, Path) or not cwd.is_dir(): + raise ConfigurationException("Wrong working directory {}".format(cwd)) + + if len(commands) != 1: + raise ConfigurationException("System should have only one command to run") + + startup_command = commands[0] + if not startup_command: + raise ConfigurationException("No startup command provided") + + self.before_start() + + self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) + + self.cmd = execute_command( + startup_command, + cwd, + bg=True, + out=str(self.out_path), + err=str(self.err_path), + ) + + self.after_start() + + def stop( + self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 + ) -> None: + """Stop system.""" + if self.cmd is not None and self.is_running(): + terminate_command(self.cmd, wait, wait_period, number_of_attempts) + + def is_running(self) -> bool: + """Check if underlying process is running.""" + return self.cmd is not None and self.cmd.is_alive() + + def get_output(self) -> Tuple[str, str]: + """Return application output.""" + if self.cmd is None or self.out_path is None or self.err_path is None: + return ("", "") + + return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) + + +class SystemControllerSingleInstance(SystemController): + """System controller with support of system's single instance.""" + + def __init__(self, pid_file_path: Optional[Path] = None) -> None: + """Create new instance of the service controller.""" + super().__init__() + self.pid_file_path = pid_file_path + + def before_start(self) -> None: + """Run actions before system start.""" + self._check_if_previous_instance_is_running() + + def after_start(self) -> None: + """Run actions after system start.""" + self._save_process_info() + + def _check_if_previous_instance_is_running(self) -> None: + """Check if another instance of the system is running.""" + process_info = read_process_info(self._pid_file()) + + for item in process_info: + try: + process = psutil.Process(item.pid) + same_process = ( + process.name() == item.name + and process.exe() == item.executable + and process.cwd() == item.cwd + ) + if same_process: + print( + "Stopping previous instance of the system [{}]".format(item.pid) + ) + terminate_external_process(process) + except psutil.NoSuchProcess: + pass + + def _save_process_info(self, wait_period: float = 2) -> None: + """Save information about system's processes.""" + if self.cmd is None or not self.is_running(): + return + + # give some time for the system to start + time.sleep(wait_period) + + save_process_info(self.cmd.process.pid, self._pid_file()) + + def _pid_file(self) -> Path: + """Return path to file which is used for saving process info.""" + if not self.pid_file_path: + raise Exception("No pid file path presented") + + return self.pid_file_path diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py new file mode 100644 index 0000000..1653ee2 --- /dev/null +++ b/src/aiet/backend/execution.py @@ -0,0 +1,859 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Application execution module.""" +import itertools +import json +import random +import re +import string +import sys +import time +import warnings +from collections import defaultdict +from contextlib import contextmanager +from contextlib import ExitStack +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import ContextManager +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypedDict +from typing import Union + +from filelock import FileLock +from filelock import Timeout + +from aiet.backend.application import Application +from aiet.backend.application import get_application +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import DataPaths +from aiet.backend.common import Param +from aiet.backend.common import parse_raw_parameter +from aiet.backend.common import resolve_all_parameters +from aiet.backend.output_parser import Base64OutputParser +from aiet.backend.output_parser import OutputParser +from aiet.backend.output_parser import RegexOutputParser +from aiet.backend.system import ControlledSystem +from aiet.backend.system import get_system +from aiet.backend.system import StandaloneSystem +from aiet.backend.system import System +from aiet.backend.tool import get_tool +from aiet.backend.tool import Tool +from aiet.utils.fs import recreate_directory +from aiet.utils.fs import remove_directory +from aiet.utils.fs import valid_for_filename +from aiet.utils.proc import run_and_wait + + +class AnotherInstanceIsRunningException(Exception): + """Concurrent execution error.""" + + +class ConnectionException(Exception): + """Connection exception.""" + + +class ExecutionParams(TypedDict, total=False): + """Execution parameters.""" + + disable_locking: bool + unique_build_dir: bool + + +class ExecutionContext: + """Command execution context.""" + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + app: Union[Application, Tool], + app_params: List[str], + system: Optional[System], + system_params: List[str], + custom_deploy_data: Optional[List[DataPaths]] = None, + execution_params: Optional[ExecutionParams] = None, + report_file: Optional[Path] = None, + ): + """Init execution context.""" + self.app = app + self.app_params = app_params + self.custom_deploy_data = custom_deploy_data or [] + self.system = system + self.system_params = system_params + self.execution_params = execution_params or ExecutionParams() + self.report_file = report_file + + self.reporter: Optional[Reporter] + if self.report_file: + # Create reporter with output parsers + parsers: List[OutputParser] = [] + if system and system.reporting: + # Add RegexOutputParser, if it is configured in the system + parsers.append(RegexOutputParser("system", system.reporting["regex"])) + # Add Base64 parser for applications + parsers.append(Base64OutputParser("application")) + self.reporter = Reporter(parsers=parsers) + else: + self.reporter = None # No reporter needed. + + self.param_resolver = ParamResolver(self) + self._resolved_build_dir: Optional[Path] = None + + @property + def is_deploy_needed(self) -> bool: + """Check if application requires data deployment.""" + if isinstance(self.app, Application): + return ( + len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0 + ) + return False + + @property + def is_locking_required(self) -> bool: + """Return true if any form of locking required.""" + return not self._disable_locking() and ( + self.app.lock or (self.system is not None and self.system.lock) + ) + + @property + def is_build_required(self) -> bool: + """Return true if application build required.""" + return "build" in self.app.commands + + @property + def is_unique_build_dir_required(self) -> bool: + """Return true if unique build dir required.""" + return self.execution_params.get("unique_build_dir", False) + + def build_dir(self) -> Path: + """Return resolved application build dir.""" + if self._resolved_build_dir is not None: + return self._resolved_build_dir + + if ( + not isinstance(self.app.config_location, Path) + or not self.app.config_location.is_dir() + ): + raise ConfigurationException( + "Application {} has wrong config location".format(self.app.name) + ) + + _build_dir = self.app.build_dir + if _build_dir: + _build_dir = resolve_all_parameters(_build_dir, self.param_resolver) + + if not _build_dir: + raise ConfigurationException( + "No build directory defined for the app {}".format(self.app.name) + ) + + if self.is_unique_build_dir_required: + random_suffix = "".join( + random.choices(string.ascii_lowercase + string.digits, k=7) + ) + _build_dir = "{}_{}".format(_build_dir, random_suffix) + + self._resolved_build_dir = self.app.config_location / _build_dir + return self._resolved_build_dir + + def _disable_locking(self) -> bool: + """Return true if locking should be disabled.""" + return self.execution_params.get("disable_locking", False) + + +class ParamResolver: + """Parameter resolver.""" + + def __init__(self, context: ExecutionContext): + """Init parameter resolver.""" + self.ctx = context + + @staticmethod + def resolve_user_params( + cmd_name: Optional[str], + index_or_alias: str, + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Resolve user params.""" + if not cmd_name or resolved_params is None: + raise ConfigurationException("Unable to resolve user params") + + param_value: Optional[str] = None + param: Optional[Param] = None + + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(resolved_params)): + raise ConfigurationException( + "Invalid index {} for user params of command {}".format(i, cmd_name) + ) + param_value, param = resolved_params[i] + else: + for val, par in resolved_params: + if par.alias == index_or_alias: + param_value, param = val, par + break + + if param is None: + raise ConfigurationException( + "No user parameter for command '{}' with alias '{}'.".format( + cmd_name, index_or_alias + ) + ) + + if param_value: + # We need to handle to cases of parameters here: + # 1) Optional parameters (non-positional with a name and value) + # 2) Positional parameters (value only, no name needed) + # Default to empty strings for positional arguments + param_name = "" + separator = "" + if param.name is not None: + # A valid param name means we have an optional/non-positional argument: + # The separator is an empty string in case the param_name + # has an equal sign as we have to honour it. + # If the parameter doesn't end with an equal sign then a + # space character is injected to split the parameter name + # and its value + param_name = param.name + separator = "" if param.name.endswith("=") else " " + + return "{param_name}{separator}{param_value}".format( + param_name=param_name, + separator=separator, + param_value=param_value, + ) + + if param.name is None: + raise ConfigurationException( + "Missing user parameter with alias '{}' for command '{}'.".format( + index_or_alias, cmd_name + ) + ) + + return param.name # flag: just return the parameter name + + def resolve_commands_and_params( + self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str + ) -> str: + """Resolve command or command's param value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + backend_params = self.ctx.system_params + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + backend_params = self.ctx.app_params + + if cmd_name not in backend.commands: + raise ConfigurationException("Command {} not found".format(cmd_name)) + + if return_params: + params = backend.resolved_parameters(cmd_name, backend_params) + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(params)): + raise ConfigurationException( + "Invalid parameter index {} for command {}".format(i, cmd_name) + ) + + param_value = params[i][0] + else: + param_value = None + for value, param in params: + if param.alias == index_or_alias: + param_value = value + break + + if not param_value: + raise ConfigurationException( + ( + "No value for parameter with index or alias {} of command {}" + ).format(index_or_alias, cmd_name) + ) + return param_value + + if not index_or_alias.isnumeric(): + raise ConfigurationException("Bad command index {}".format(index_or_alias)) + + i = int(index_or_alias) + commands = backend.build_command(cmd_name, backend_params, self.param_resolver) + if i not in range(len(commands)): + raise ConfigurationException( + "Invalid index {} for command {}".format(i, cmd_name) + ) + + return commands[i] + + def resolve_variables(self, backend_type: str, var_name: str) -> str: + """Resolve variable value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + + if var_name not in backend.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return backend.variables[var_name] + + def param_matcher( + self, + param_name: str, + cmd_name: Optional[str], + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Regexp to resolve a param from the param_name.""" + # this pattern supports parameter names like "application.commands.run:0" and + # "system.commands.run.params:0" + # Note: 'software' is included for backward compatibility. + commands_and_params_match = re.match( + r"(?Papplication|software|tool|system)[.]commands[.]" + r"(?P\w+)" + r"(?P[.]params|)[:]" + r"(?P\w+)", + param_name, + ) + + if commands_and_params_match: + backend_type, cmd_name, return_params, index_or_alias = ( + commands_and_params_match["type"], + commands_and_params_match["name"], + commands_and_params_match["params"], + commands_and_params_match["index_or_alias"], + ) + return self.resolve_commands_and_params( + backend_type, cmd_name, bool(return_params), index_or_alias + ) + + # Note: 'software' is included for backward compatibility. + variables_match = re.match( + r"(?Papplication|software|tool|system)[.]variables:(?P\w+)", + param_name, + ) + if variables_match: + backend_type, var_name = ( + variables_match["type"], + variables_match["var_name"], + ) + return self.resolve_variables(backend_type, var_name) + + user_params_match = re.match(r"user_params:(?P\w+)", param_name) + if user_params_match: + index_or_alias = user_params_match["index_or_alias"] + return self.resolve_user_params(cmd_name, index_or_alias, resolved_params) + + raise ConfigurationException( + "Unable to resolve parameter {}".format(param_name) + ) + + def param_resolver( + self, + param_name: str, + cmd_name: Optional[str] = None, + resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + ) -> str: + """Resolve parameter value based on current execution context.""" + # Note: 'software.*' is included for backward compatibility. + resolved_param = None + if param_name in ["application.name", "tool.name", "software.name"]: + resolved_param = self.ctx.app.name + elif param_name in [ + "application.description", + "tool.description", + "software.description", + ]: + resolved_param = self.ctx.app.description + elif self.ctx.app.config_location and ( + param_name + in ["application.config_dir", "tool.config_dir", "software.config_dir"] + ): + resolved_param = str(self.ctx.app.config_location.absolute()) + elif self.ctx.app.build_dir and ( + param_name + in ["application.build_dir", "tool.build_dir", "software.build_dir"] + ): + resolved_param = str(self.ctx.build_dir().absolute()) + elif self.ctx.system is not None: + if param_name == "system.name": + resolved_param = self.ctx.system.name + elif param_name == "system.description": + resolved_param = self.ctx.system.description + elif param_name == "system.config_dir" and self.ctx.system.config_location: + resolved_param = str(self.ctx.system.config_location.absolute()) + + if not resolved_param: + resolved_param = self.param_matcher(param_name, cmd_name, resolved_params) + return resolved_param + + def __call__( + self, + param_name: str, + cmd_name: Optional[str] = None, + resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + ) -> str: + """Resolve provided parameter.""" + return self.param_resolver(param_name, cmd_name, resolved_params) + + +class Reporter: + """Report metrics from the simulation output.""" + + def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None: + """Create an empty reporter (i.e. no parsers registered).""" + self.parsers: List[OutputParser] = parsers if parsers is not None else [] + self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) + + def parse(self, output: bytearray) -> None: + """Parse output and append parsed metrics to internal report dict.""" + for parser in self.parsers: + # Merge metrics from different parsers (do not overwrite) + self._report[parser.name]["metrics"].update(parser(output)) + + def get_filtered_output(self, output: bytearray) -> bytearray: + """Filter the output according to each parser.""" + for parser in self.parsers: + output = parser.filter_out_parsed_content(output) + return output + + def report(self, ctx: ExecutionContext) -> Dict[str, Any]: + """Add static simulation info to parsed data and return the report.""" + report: Dict[str, Any] = defaultdict(dict) + # Add static simulation info + report.update(self._static_info(ctx)) + # Add metrics parsed from the output + for key, val in self._report.items(): + report[key].update(val) + return report + + @staticmethod + def save(report: Dict[str, Any], report_file: Path) -> None: + """Save the report to a JSON file.""" + with open(report_file, "w", encoding="utf-8") as file: + json.dump(report, file, indent=4) + + @staticmethod + def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: + """ + Build a dict of all parameters, {name:value}. + + Param values taken from command line if specified, defaults otherwise. + """ + # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} + app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) + + # a map of params declared in the application, with values taken from the CLI, + # defaults otherwise + all_params = { + (p.alias or p.name): app_params_map.get( + cast(str, p.name), cast(str, p.default_value) + ) + for cmd in backend.commands.values() + for p in cmd.params + } + return cast(Dict[str, str], all_params) + + @staticmethod + def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: + """Extract static simulation information from the context.""" + if ctx.system is None: + raise ValueError("No system available to report.") + + info = { + "system": { + "name": ctx.system.name, + "params": Reporter._compute_all_params(ctx.system_params, ctx.system), + }, + "application": { + "name": ctx.app.name, + "params": Reporter._compute_all_params(ctx.app_params, ctx.app), + }, + } + return info + + +def validate_parameters( + backend: Backend, command_names: List[str], params: List[str] +) -> None: + """Check parameters passed to backend.""" + for param in params: + acceptable = any( + backend.validate_parameter(command_name, param) + for command_name in command_names + if command_name in backend.commands + ) + + if not acceptable: + backend_type = "System" if isinstance(backend, System) else "Application" + raise ValueError( + "{} parameter '{}' not valid for command '{}'".format( + backend_type, param, " or ".join(command_names) + ) + ) + + +def get_application_by_name_and_system( + application_name: str, system_name: str +) -> Application: + """Get application.""" + applications = get_application(application_name, system_name) + if not applications: + raise ValueError( + "Application '{}' doesn't support the system '{}'".format( + application_name, system_name + ) + ) + + if len(applications) != 1: + raise ValueError( + "Error during getting application {} for the system {}".format( + application_name, system_name + ) + ) + + return applications[0] + + +def get_application_and_system( + application_name: str, system_name: str +) -> Tuple[Application, System]: + """Return application and system by provided names.""" + system = get_system(system_name) + if not system: + raise ValueError("System {} is not found".format(system_name)) + + application = get_application_by_name_and_system(application_name, system_name) + + return application, system + + +def execute_application_command( # pylint: disable=too-many-arguments + command_name: str, + application_name: str, + application_params: List[str], + system_name: str, + system_params: List[str], + custom_deploy_data: List[DataPaths], +) -> None: + """Execute application command. + + .. deprecated:: 21.12 + """ + warnings.warn( + "Use 'run_application()' instead. Use of 'execute_application_command()' is " + "deprecated and might be removed in a future release.", + DeprecationWarning, + ) + + if command_name not in ["build", "run"]: + raise ConfigurationException("Unsupported command {}".format(command_name)) + + application, system = get_application_and_system(application_name, system_name) + validate_parameters(application, [command_name], application_params) + validate_parameters(system, [command_name], system_params) + + ctx = ExecutionContext( + app=application, + app_params=application_params, + system=system, + system_params=system_params, + custom_deploy_data=custom_deploy_data, + ) + + if command_name == "run": + execute_application_command_run(ctx) + else: + execute_application_command_build(ctx) + + +# pylint: disable=too-many-arguments +def run_application( + application_name: str, + application_params: List[str], + system_name: str, + system_params: List[str], + custom_deploy_data: List[DataPaths], + report_file: Optional[Path] = None, +) -> None: + """Run application on the provided system.""" + application, system = get_application_and_system(application_name, system_name) + validate_parameters(application, ["build", "run"], application_params) + validate_parameters(system, ["build", "run"], system_params) + + execution_params = ExecutionParams() + if isinstance(system, StandaloneSystem): + execution_params["disable_locking"] = True + execution_params["unique_build_dir"] = True + + ctx = ExecutionContext( + app=application, + app_params=application_params, + system=system, + system_params=system_params, + custom_deploy_data=custom_deploy_data, + execution_params=execution_params, + report_file=report_file, + ) + + with build_dir_manager(ctx): + if ctx.is_build_required: + execute_application_command_build(ctx) + + execute_application_command_run(ctx) + + +def execute_application_command_build(ctx: ExecutionContext) -> None: + """Execute application command 'build'.""" + with ExitStack() as context_stack: + for manager in get_context_managers("build", ctx): + context_stack.enter_context(manager(ctx)) + + build_dir = ctx.build_dir() + recreate_directory(build_dir) + + build_commands = ctx.app.build_command( + "build", ctx.app_params, ctx.param_resolver + ) + execute_commands_locally(build_commands, build_dir) + + +def execute_commands_locally(commands: List[str], cwd: Path) -> None: + """Execute list of commands locally.""" + for command in commands: + print("Running: {}".format(command)) + run_and_wait( + command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr + ) + + +def execute_application_command_run(ctx: ExecutionContext) -> None: + """Execute application command.""" + assert ctx.system is not None, "System must be provided." + if ctx.is_deploy_needed and not ctx.system.supports_deploy: + raise ConfigurationException( + "System {} does not support data deploy".format(ctx.system.name) + ) + + with ExitStack() as context_stack: + for manager in get_context_managers("run", ctx): + context_stack.enter_context(manager(ctx)) + + print("Generating commands to execute") + commands_to_run = build_run_commands(ctx) + + if ctx.system.connectable: + establish_connection(ctx) + + if ctx.system.supports_deploy: + deploy_data(ctx) + + for command in commands_to_run: + print("Running: {}".format(command)) + exit_code, std_output, std_err = ctx.system.run(command) + + if exit_code != 0: + print("Application exited with exit code {}".format(exit_code)) + + if ctx.reporter: + ctx.reporter.parse(std_output) + std_output = ctx.reporter.get_filtered_output(std_output) + + print(std_output.decode("utf8"), end="") + print(std_err.decode("utf8"), end="") + + if ctx.reporter: + report = ctx.reporter.report(ctx) + ctx.reporter.save(report, cast(Path, ctx.report_file)) + + +def establish_connection( + ctx: ExecutionContext, retries: int = 90, interval: float = 15.0 +) -> None: + """Establish connection with the system.""" + assert ctx.system is not None, "System is required." + host, port = ctx.system.connection_details() + print( + "Trying to establish connection with '{}:{}' - " + "{} retries every {} seconds ".format(host, port, retries, interval), + end="", + ) + + try: + for _ in range(retries): + print(".", end="", flush=True) + + if ctx.system.establish_connection(): + break + + if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running(): + print( + "\n\n---------- {} execution failed ----------".format( + ctx.system.name + ) + ) + stdout, stderr = ctx.system.get_output() + print(stdout) + print(stderr) + + raise Exception("System is not running") + + wait(interval) + else: + raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port)) + finally: + print() + + +def wait(interval: float) -> None: + """Wait for a period of time.""" + time.sleep(interval) + + +def deploy_data(ctx: ExecutionContext) -> None: + """Deploy data to the system.""" + if isinstance(ctx.app, Application): + # Only application can deploy data (tools can not) + assert ctx.system is not None, "System is required." + for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data): + print("Deploying {} onto {}".format(item.src, item.dst)) + ctx.system.deploy(item.src, item.dst) + + +def build_run_commands(ctx: ExecutionContext) -> List[str]: + """Build commands to run application.""" + if isinstance(ctx.system, StandaloneSystem): + return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver) + + return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver) + + +@contextmanager +def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Context manager used for system initialisation before run.""" + system = cast(ControlledSystem, ctx.system) + commands = system.build_command("run", ctx.system_params, ctx.param_resolver) + pid_file_path: Optional[Path] = None + if ctx.is_locking_required: + file_lock_path = get_file_lock_path(ctx) + pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem) + + system.start(commands, ctx.is_locking_required, pid_file_path) + try: + yield + finally: + print("Shutting down sequence...") + print("Stopping {}... (It could take few seconds)".format(system.name)) + system.stop(wait=True) + print("{} stopped successfully.".format(system.name)) + + +@contextmanager +def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Lock execution manager.""" + file_lock_path = get_file_lock_path(ctx) + file_lock = FileLock(str(file_lock_path)) + + try: + file_lock.acquire(timeout=1) + except Timeout as error: + raise AnotherInstanceIsRunningException() from error + + try: + yield + finally: + file_lock.release() + + +def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path: + """Get file lock path.""" + lock_modules = [] + if ctx.app.lock: + lock_modules.append(ctx.app.name) + if ctx.system is not None and ctx.system.lock: + lock_modules.append(ctx.system.name) + lock_filename = "" + if lock_modules: + lock_filename = "_".join(["middleware"] + lock_modules) + ".lock" + + if lock_filename: + lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver) + lock_filename = valid_for_filename(lock_filename) + + if not lock_filename: + raise ConfigurationException("No filename for lock provided") + + if not isinstance(lock_dir, Path) or not lock_dir.is_dir(): + raise ConfigurationException( + "Invalid directory {} for lock files provided".format(lock_dir) + ) + + return lock_dir / lock_filename + + +@contextmanager +def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Build directory manager.""" + try: + yield + finally: + if ( + ctx.is_build_required + and ctx.is_unique_build_dir_required + and ctx.build_dir().is_dir() + ): + remove_directory(ctx.build_dir()) + + +def get_context_managers( + command_name: str, ctx: ExecutionContext +) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]: + """Get context manager for the system.""" + managers = [] + + if ctx.is_locking_required: + managers.append(lock_execution_manager) + + if command_name == "run": + if isinstance(ctx.system, ControlledSystem): + managers.append(controlled_system_manager) + + return managers + + +def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool: + """Return tool (optionally by provided system name.""" + tools = get_tool(tool_name, system_name) + if not tools: + raise ConfigurationException( + "Tool '{}' not found or doesn't support the system '{}'".format( + tool_name, system_name + ) + ) + if len(tools) != 1: + raise ConfigurationException( + "Please specify the system for tool {}.".format(tool_name) + ) + tool = tools[0] + + return tool + + +def execute_tool_command( + tool_name: str, + tool_params: List[str], + system_name: Optional[str] = None, +) -> None: + """Execute the tool command locally calling the 'run' command.""" + tool = get_tool_by_system(tool_name, system_name) + ctx = ExecutionContext( + app=tool, app_params=tool_params, system=None, system_params=[] + ) + commands = tool.build_command("run", tool_params, ctx.param_resolver) + + execute_commands_locally(commands, Path.cwd()) diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py new file mode 100644 index 0000000..111772a --- /dev/null +++ b/src/aiet/backend/output_parser.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Definition of output parsers (including base class OutputParser).""" +import base64 +import json +import re +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict +from typing import Union + + +class OutputParser(ABC): + """Abstract base class for output parsers.""" + + def __init__(self, name: str) -> None: + """Set up the name of the parser.""" + super().__init__() + self.name = name + + @abstractmethod + def __call__(self, output: bytearray) -> Dict[str, Any]: + """Parse the output and return a map of names to metrics.""" + return {} + + # pylint: disable=no-self-use + def filter_out_parsed_content(self, output: bytearray) -> bytearray: + """ + Filter out the parsed content from the output. + + Does nothing by default. Can be overridden in subclasses. + """ + return output + + +class RegexOutputParser(OutputParser): + """Parser of standard output data using regular expressions.""" + + _TYPE_MAP = {"str": str, "float": float, "int": int} + + def __init__( + self, + name: str, + regex_config: Dict[str, Dict[str, str]], + ) -> None: + """ + Set up the parser with the regular expressions. + + The regex_config is mapping from a name to a dict with keys 'pattern' + and 'type': + - The 'pattern' holds the regular expression that must contain exactly + one capturing parenthesis + - The 'type' can be one of ['str', 'float', 'int']. + + Example: + ``` + {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}} + ``` + + The different regular expressions from the config are combined using + non-capturing parenthesis, i.e. regular expressions must not overlap + if more than one match per line is expected. + """ + super().__init__(name) + + self._verify_config(regex_config) + self._regex_cfg = regex_config + + # Compile regular expression to match in the output + self._regex = re.compile( + "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values()) + ) + + def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]: + """ + Parse the output and return a map of names to metrics. + + Example: + Assuming a regex_config as used as example in `__init__()` and the + following output: + ``` + Simulation finished: + SIMULATION_STATUS = SUCCESS + Simulation DONE + ``` + Then calling the parser should return the following dict: + ``` + { + "Metric1": "SUCCESS" + } + ``` + """ + metrics = {} + output_str = output.decode("utf-8") + results = self._regex.findall(output_str) + for line_result in results: + for idx, (name, cfg) in enumerate(self._regex_cfg.items()): + # The result(s) returned by findall() are either a single string + # or a tuple (depending on the number of groups etc.) + result = ( + line_result if isinstance(line_result, str) else line_result[idx] + ) + if result: + mapped_result = self._TYPE_MAP[cfg["type"]](result) + metrics[name] = mapped_result + return metrics + + def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None: + """Make sure we have a valid regex_config. + + I.e. + - Exactly one capturing parenthesis per pattern + - Correct types + """ + for name, cfg in regex_config.items(): + # Check that there is one capturing group defined in the pattern. + regex = re.compile(cfg["pattern"]) + if regex.groups != 1: + raise ValueError( + f"Pattern for metric '{name}' must have exactly one " + f"capturing parenthesis, but it has {regex.groups}." + ) + # Check if type is supported + if not cfg["type"] in self._TYPE_MAP: + raise TypeError( + f"Type '{cfg['type']}' for metric '{name}' is not " + f"supported. Choose from: {list(self._TYPE_MAP.keys())}." + ) + + +class Base64OutputParser(OutputParser): + """ + Parser to extract base64-encoded JSON from tagged standard output. + + Example of the tagged output: + ``` + # Encoded JSON: {"test": 1234} + eyJ0ZXN0IjogMTIzNH0 + ``` + """ + + TAG_NAME = "metrics" + + def __init__(self, name: str) -> None: + """Set up the regular expression to extract tagged strings.""" + super().__init__(name) + self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)") + + def __call__(self, output: bytearray) -> Dict[str, Any]: + """ + Parse the output and return a map of index (as string) to decoded JSON. + + Example: + Using the tagged output from the class docs the parser should return + the following dict: + ``` + { + "0": {"test": 1234} + } + ``` + """ + metrics = {} + output_str = output.decode("utf-8") + results = self._regex.findall(output_str) + for idx, result_base64 in enumerate(results): + result_json = base64.b64decode(result_base64, validate=True) + result = json.loads(result_json) + metrics[str(idx)] = result + + return metrics + + def filter_out_parsed_content(self, output: bytearray) -> bytearray: + """Filter out base64-encoded content from the output.""" + output_str = self._regex.sub("", output.decode("utf-8")) + return bytearray(output_str.encode("utf-8")) diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py new file mode 100644 index 0000000..c621436 --- /dev/null +++ b/src/aiet/backend/protocol.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain protocol related classes and functions.""" +from abc import ABC +from abc import abstractmethod +from contextlib import closing +from pathlib import Path +from typing import Any +from typing import cast +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Union + +import paramiko + +from aiet.backend.common import ConfigurationException +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import SSHConfig +from aiet.utils.proc import run_and_wait + + +# Redirect all paramiko thread exceptions to a file otherwise these will be +# printed to stderr. +paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) + + +class SSHConnectionException(Exception): + """SSH connection exception.""" + + +class SupportsClose(ABC): + """Class indicates support of close operation.""" + + @abstractmethod + def close(self) -> None: + """Close protocol session.""" + + +class SupportsDeploy(ABC): + """Class indicates support of deploy operation.""" + + @abstractmethod + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Abstract method for deploy data.""" + + +class SupportsConnection(ABC): + """Class indicates that protocol uses network connections.""" + + @abstractmethod + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + + @abstractmethod + def connection_details(self) -> Tuple[str, int]: + """Return connection details (host, port).""" + + +class Protocol(ABC): + """Abstract class for representing the protocol.""" + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + self.__dict__.update(iterable, **kwargs) + self._validate() + + @abstractmethod + def _validate(self) -> None: + """Abstract method for config data validation.""" + + @abstractmethod + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Abstract method for running commands. + + Returns a tuple: (exit_code, stdout, stderr) + """ + + +class CustomSFTPClient(paramiko.SFTPClient): + """Class for creating a custom sftp client.""" + + def put_dir(self, source: Path, target: str) -> None: + """Upload the source directory to the target path. + + The target directory needs to exists and the last directory of the + source will be created under the target with all its content. + """ + # Create the target directory + self._mkdir(target, ignore_existing=True) + # Create the last directory in the source on the target + self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) + # Go through the whole content of source + for item in sorted(source.glob("**/*")): + relative_path = item.relative_to(source.parent) + remote_target = target / relative_path + if item.is_file(): + self.put(str(item), str(remote_target)) + else: + self._mkdir(str(remote_target), ignore_existing=True) + + def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: + """Extend mkdir functionality. + + This version adds an option to not fail if the folder exists. + """ + try: + super().mkdir(path, mode) + except IOError as error: + if ignore_existing: + pass + else: + raise error + + +class LocalProtocol(Protocol): + """Class for local protocol.""" + + protocol: str + cwd: Path + + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Run command locally. + + Returns a tuple: (exit_code, stdout, stderr) + """ + if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): + raise ConfigurationException("Wrong working directory {}".format(self.cwd)) + + stdout = bytearray() + stderr = bytearray() + + return run_and_wait( + command, self.cwd, terminate_on_error=True, out=stdout, err=stderr + ) + + def _validate(self) -> None: + """Validate protocol configuration.""" + assert hasattr(self, "protocol") and self.protocol == "local" + assert hasattr(self, "cwd") + + +class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): + """Class for SSH protocol.""" + + protocol: str + username: str + password: str + hostname: str + port: int + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + super().__init__(iterable, **kwargs) + # Internal state to store if the system is connectable. It will be set + # to true at the first connection instance + self.client: Optional[paramiko.client.SSHClient] = None + self.port = int(self.port) + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command over SSH. + + Returns a tuple: (exit_code, stdout, stderr) + """ + transport = self._get_transport() + with closing(transport.open_session()) as channel: + # Enable shell's .profile settings and execute command + channel.exec_command("bash -l -c '{}'".format(command)) + exit_status = -1 + stdout = bytearray() + stderr = bytearray() + while True: + if channel.exit_status_ready(): + exit_status = channel.recv_exit_status() + # Call it one last time to read any leftover in the channel + self._recv_stdout_err(channel, stdout, stderr) + break + self._recv_stdout_err(channel, stdout, stderr) + + return exit_status, stdout, stderr + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy src to remote dst over SSH. + + src and dst should be path to a file or directory. + """ + transport = self._get_transport() + sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) + + with closing(sftp): + if src.is_dir(): + sftp.put_dir(src, dst) + elif src.is_file(): + sftp.put(str(src), dst) + else: + raise Exception("Deploy error: file type not supported") + + # After the deployment of files, sync the remote filesystem to flush + # buffers to hard disk + self.run("sync") + + def close(self) -> None: + """Close protocol session.""" + if self.client is not None: + print("Try syncing remote file system...") + # Before stopping the system, we try to run sync to make sure all + # data are flushed on disk. + self.run("sync", retry=False) + self._close_client(self.client) + + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + if self.client is not None: + return True + + self.client = self._connect() + return self.client is not None + + def _get_transport(self) -> paramiko.transport.Transport: + """Get transport.""" + self.establish_connection() + + if self.client is None: + raise SSHConnectionException( + "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) + ) + + transport = self.client.get_transport() + if not transport: + raise Exception("Unable to get transport") + + return transport + + def connection_details(self) -> Tuple[str, int]: + """Return connection details of underlying system.""" + return (self.hostname, self.port) + + def _connect(self) -> Optional[paramiko.client.SSHClient]: + """Try to establish connection.""" + client: Optional[paramiko.client.SSHClient] = None + try: + client = paramiko.client.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + self.hostname, + self.port, + self.username, + self.password, + # next parameters should be set to False to disable authentication + # using ssh keys + allow_agent=False, + look_for_keys=False, + ) + return client + except ( + # OSError raised on first attempt to connect when running inside Docker + OSError, + paramiko.ssh_exception.NoValidConnectionsError, + paramiko.ssh_exception.SSHException, + ): + # even if connection is not established socket could be still + # open, it should be closed + self._close_client(client) + + return None + + @staticmethod + def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: + """Close ssh client.""" + try: + if client is not None: + client.close() + except Exception: # pylint: disable=broad-except + pass + + @classmethod + def _recv_stdout_err( + cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray + ) -> None: + """Read from channel to stdout/stder.""" + chunk_size = 512 + if channel.recv_ready(): + stdout_chunk = channel.recv(chunk_size) + stdout.extend(stdout_chunk) + if channel.recv_stderr_ready(): + stderr_chunk = channel.recv_stderr(chunk_size) + stderr.extend(stderr_chunk) + + def _validate(self) -> None: + """Check if there are all the info for establishing the connection.""" + assert hasattr(self, "protocol") and self.protocol == "ssh" + assert hasattr(self, "username") + assert hasattr(self, "password") + assert hasattr(self, "hostname") + assert hasattr(self, "port") + + +class ProtocolFactory: + """Factory class to return the appropriate Protocol class.""" + + @staticmethod + def get_protocol( + config: Optional[Union[SSHConfig, LocalProtocolConfig]], + **kwargs: Union[str, Path, None] + ) -> Union[SSHProtocol, LocalProtocol]: + """Return the right protocol instance based on the config.""" + if not config: + raise ValueError("No protocol config provided") + + protocol = config["protocol"] + if protocol == "ssh": + return SSHProtocol(config) + + if protocol == "local": + cwd = kwargs.get("cwd") + return LocalProtocol(config, cwd=cwd) + + raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py new file mode 100644 index 0000000..dec175a --- /dev/null +++ b/src/aiet/backend/source.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain source related classes and functions.""" +import os +import shutil +import tarfile +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from tarfile import TarFile +from typing import Optional +from typing import Union + +from aiet.backend.common import AIET_CONFIG_FILE +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_config +from aiet.backend.common import is_backend_directory +from aiet.backend.common import load_config +from aiet.backend.config import BackendConfig +from aiet.utils.fs import copy_directory_content + + +class Source(ABC): + """Source class.""" + + @abstractmethod + def name(self) -> Optional[str]: + """Get source name.""" + + @abstractmethod + def config(self) -> Optional[BackendConfig]: + """Get configuration file content.""" + + @abstractmethod + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + + @abstractmethod + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + + +class DirectorySource(Source): + """DirectorySource class.""" + + def __init__(self, directory_path: Path) -> None: + """Create the DirectorySource instance.""" + assert isinstance(directory_path, Path) + self.directory_path = directory_path + + def name(self) -> str: + """Return name of source.""" + return self.directory_path.name + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if not is_backend_directory(self.directory_path): + raise ConfigurationException("No configuration file found") + + config_file = get_backend_config(self.directory_path) + return load_config(config_file) + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + if not self.directory_path.is_dir(): + raise ConfigurationException( + "Directory {} does not exist".format(self.directory_path) + ) + + copy_directory_content(self.directory_path, destination) + + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + return True + + +class TarArchiveSource(Source): + """TarArchiveSource class.""" + + def __init__(self, archive_path: Path) -> None: + """Create the TarArchiveSource class.""" + assert isinstance(archive_path, Path) + self.archive_path = archive_path + self._config: Optional[BackendConfig] = None + self._has_top_level_folder: Optional[bool] = None + self._name: Optional[str] = None + + def _read_archive_content(self) -> None: + """Read various information about archive.""" + # get source name from archive name (everything without extensions) + extensions = "".join(self.archive_path.suffixes) + self._name = self.archive_path.name.rstrip(extensions) + + if not self.archive_path.exists(): + return + + with self._open(self.archive_path) as archive: + try: + config_entry = archive.getmember(AIET_CONFIG_FILE) + self._has_top_level_folder = False + except KeyError as error_no_config: + try: + archive_entries = archive.getnames() + entries_common_prefix = os.path.commonprefix(archive_entries) + top_level_dir = entries_common_prefix.rstrip("/") + + if not top_level_dir: + raise RuntimeError( + "Archive has no top level directory" + ) from error_no_config + + config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE) + + config_entry = archive.getmember(config_path) + self._has_top_level_folder = True + self._name = top_level_dir + except (KeyError, RuntimeError) as error_no_root_dir_or_config: + raise ConfigurationException( + "No configuration file found" + ) from error_no_root_dir_or_config + + content = archive.extractfile(config_entry) + self._config = load_config(content) + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if self._config is None: + self._read_archive_content() + + return self._config + + def name(self) -> Optional[str]: + """Return name of the source.""" + if self._name is None: + self._read_archive_content() + + return self._name + + def create_destination(self) -> bool: + """Return True if destination folder must be created before installation.""" + if self._has_top_level_folder is None: + self._read_archive_content() + + return not self._has_top_level_folder + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + with self._open(self.archive_path) as archive: + archive.extractall(destination) + + def _open(self, archive_path: Path) -> TarFile: + """Open archive file.""" + if not archive_path.is_file(): + raise ConfigurationException("File {} does not exist".format(archive_path)) + + if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): + mode = "r:gz" + else: + raise ConfigurationException( + "Unsupported archive type {}".format(archive_path) + ) + + # The returned TarFile object can be used as a context manager (using + # 'with') by the calling instance. + return tarfile.open( # pylint: disable=consider-using-with + self.archive_path, mode=mode + ) + + +def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: + """Return appropriate source instance based on provided source path.""" + if source_path.is_file(): + return TarArchiveSource(source_path) + + if source_path.is_dir(): + return DirectorySource(source_path) + + raise ConfigurationException("Unable to read {}".format(source_path)) + + +def create_destination_and_install(source: Source, resource_path: Path) -> None: + """Create destination directory and install source. + + This function is used for actual installation of system/backend New + directory will be created inside :resource_path: if needed If for example + archive contains top level folder then no need to create new directory + """ + destination = resource_path + create_destination = source.create_destination() + + if create_destination: + name = source.name() + if not name: + raise ConfigurationException("Unable to get source name") + + destination = resource_path / name + destination.mkdir() + try: + source.install_into(destination) + except Exception as error: + if create_destination: + shutil.rmtree(destination) + raise error diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py new file mode 100644 index 0000000..48f1bb1 --- /dev/null +++ b/src/aiet/backend/system.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""System backend module.""" +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_configs +from aiet.backend.common import get_backend_directories +from aiet.backend.common import load_config +from aiet.backend.common import remove_backend +from aiet.backend.config import SystemConfig +from aiet.backend.controller import SystemController +from aiet.backend.controller import SystemControllerSingleInstance +from aiet.backend.protocol import ProtocolFactory +from aiet.backend.protocol import SupportsClose +from aiet.backend.protocol import SupportsConnection +from aiet.backend.protocol import SupportsDeploy +from aiet.backend.source import create_destination_and_install +from aiet.backend.source import get_source +from aiet.utils.fs import get_resources + + +def get_available_systems_directory_names() -> List[str]: + """Return a list of directory names for all avialable systems.""" + return [entry.name for entry in get_backend_directories("systems")] + + +def get_available_systems() -> List["System"]: + """Return a list with all available systems.""" + available_systems = [] + for config_json in get_backend_configs("systems"): + config_entries = cast(List[SystemConfig], (load_config(config_json))) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + system = load_system(config_entry) + available_systems.append(system) + + return sorted(available_systems, key=lambda system: system.name) + + +def get_system(system_name: str) -> Optional["System"]: + """Return a system instance with the same name passed as argument.""" + available_systems = get_available_systems() + for system in available_systems: + if system_name == system.name: + return system + return None + + +def install_system(source_path: Path) -> None: + """Install new system.""" + try: + source = get_source(source_path) + config = cast(List[SystemConfig], source.config()) + systems_to_install = [load_system(entry) for entry in config] + except Exception as error: + raise ConfigurationException("Unable to read system definition") from error + + if not systems_to_install: + raise ConfigurationException("No system definition found") + + available_systems = get_available_systems() + already_installed = [s for s in systems_to_install if s in available_systems] + if already_installed: + names = [system.name for system in already_installed] + raise ConfigurationException( + "Systems [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_resources("systems")) + + +def remove_system(directory_name: str) -> None: + """Remove system.""" + remove_backend(directory_name, "systems") + + +class System(Backend): + """System class.""" + + def __init__(self, config: SystemConfig) -> None: + """Construct the System class using the dictionary passed.""" + super().__init__(config) + + self._setup_data_transfer(config) + self._setup_reporting(config) + + def _setup_data_transfer(self, config: SystemConfig) -> None: + data_transfer_config = config.get("data_transfer") + protocol = ProtocolFactory().get_protocol( + data_transfer_config, cwd=self.config_location + ) + self.protocol = protocol + + def _setup_reporting(self, config: SystemConfig) -> None: + self.reporting = config.get("reporting") + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command on the system. + + Returns a tuple: (exit_code, stdout, stderr) + """ + return self.protocol.run(command, retry) + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy files to the system.""" + if isinstance(self.protocol, SupportsDeploy): + self.protocol.deploy(src, dst, retry) + + @property + def supports_deploy(self) -> bool: + """Check if protocol supports deploy operation.""" + return isinstance(self.protocol, SupportsDeploy) + + @property + def connectable(self) -> bool: + """Check if protocol supports connection.""" + return isinstance(self.protocol, SupportsConnection) + + def establish_connection(self) -> bool: + """Establish connection with the system.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.establish_connection() + + def connection_details(self) -> Tuple[str, int]: + """Return connection details.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.connection_details() + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, System): + return False + + return super().__eq__(other) and self.name == other.name + + def get_details(self) -> Dict[str, Any]: + """Return a dictionary with all relevant information of a System.""" + output = { + "type": "system", + "name": self.name, + "description": self.description, + "data_transfer_protocol": self.protocol.protocol, + "commands": self._get_command_details(), + "annotations": self.annotations, + } + + return output + + +class StandaloneSystem(System): + """StandaloneSystem class.""" + + +def get_controller( + single_instance: bool, pid_file_path: Optional[Path] = None +) -> SystemController: + """Get system controller.""" + if single_instance: + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +class ControlledSystem(System): + """ControlledSystem class.""" + + def __init__(self, config: SystemConfig): + """Construct the ControlledSystem class using the dictionary passed.""" + super().__init__(config) + self.controller: Optional[SystemController] = None + + def start( + self, + commands: List[str], + single_instance: bool = True, + pid_file_path: Optional[Path] = None, + ) -> None: + """Launch the system.""" + if ( + not isinstance(self.config_location, Path) + or not self.config_location.is_dir() + ): + raise ConfigurationException( + "System {} has wrong config location".format(self.name) + ) + + self.controller = get_controller(single_instance, pid_file_path) + self.controller.start(commands, self.config_location) + + def is_running(self) -> bool: + """Check if system is running.""" + if not self.controller: + return False + + return self.controller.is_running() + + def get_output(self) -> Tuple[str, str]: + """Return system output.""" + if not self.controller: + return "", "" + + return self.controller.get_output() + + def stop(self, wait: bool = False) -> None: + """Stop the system.""" + if not self.controller: + raise Exception("System has not been started") + + if isinstance(self.protocol, SupportsClose): + try: + self.protocol.close() + except Exception as error: # pylint: disable=broad-except + print(error) + self.controller.stop(wait) + + +def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: + """Load system based on it's execution type.""" + data_transfer = config.get("data_transfer", {}) + protocol = data_transfer.get("protocol") + populate_shared_params(config) + + if protocol == "ssh": + return ControlledSystem(config) + + if protocol == "local": + return StandaloneSystem(config) + + raise ConfigurationException( + "Unsupported execution type for protocol {}".format(protocol) + ) + + +def populate_shared_params(config: SystemConfig) -> None: + """Populate command parameters with shared parameters.""" + user_params = config.get("user_params") + if not user_params or "shared" not in user_params: + return + + shared_user_params = user_params["shared"] + if not shared_user_params: + return + + only_aliases = all(p.get("alias") for p in shared_user_params) + if not only_aliases: + raise ConfigurationException("All shared parameters should have aliases") + + commands = config.get("commands", {}) + for cmd_name in ["build", "run"]: + command = commands.get(cmd_name) + if command is None: + commands[cmd_name] = [] + cmd_user_params = user_params.get(cmd_name) + if not cmd_user_params: + cmd_user_params = shared_user_params + else: + only_aliases = all(p.get("alias") for p in cmd_user_params) + if not only_aliases: + raise ConfigurationException( + "All parameters for command {} should have aliases".format(cmd_name) + ) + merged_by_alias = { + **{p.get("alias"): p for p in shared_user_params}, + **{p.get("alias"): p for p in cmd_user_params}, + } + cmd_user_params = list(merged_by_alias.values()) + + user_params[cmd_name] = cmd_user_params + + config["commands"] = commands + del user_params["shared"] diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py new file mode 100644 index 0000000..d643665 --- /dev/null +++ b/src/aiet/backend/tool.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tool backend module.""" +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional + +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_configs +from aiet.backend.common import get_backend_directories +from aiet.backend.common import load_application_or_tool_configs +from aiet.backend.common import load_config +from aiet.backend.config import ExtendedToolConfig +from aiet.backend.config import ToolConfig + + +def get_available_tool_directory_names() -> List[str]: + """Return a list of directory names for all available tools.""" + return [entry.name for entry in get_backend_directories("tools")] + + +def get_available_tools() -> List["Tool"]: + """Return a list with all available tools.""" + available_tools = [] + for config_json in get_backend_configs("tools"): + config_entries = cast(List[ExtendedToolConfig], load_config(config_json)) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + tools = load_tools(config_entry) + available_tools += tools + + return sorted(available_tools, key=lambda tool: tool.name) + + +def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]: + """Return a tool instance with the same name passed as argument.""" + return [ + tool + for tool in get_available_tools() + if tool.name == tool_name and (not system_name or tool.can_run_on(system_name)) + ] + + +def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]: + """Extract a list of unique tool names of all tools available.""" + return list( + set( + tool.name + for tool in get_available_tools() + if not system_name or tool.can_run_on(system_name) + ) + ) + + +class Tool(Backend): + """Class for representing a single tool component.""" + + def __init__(self, config: ToolConfig) -> None: + """Construct a Tool instance from a dict.""" + super().__init__(config) + + self.supported_systems = config.get("supported_systems", []) + + if "run" not in self.commands: + raise ConfigurationException("A Tool must have a 'run' command.") + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Tool): + return False + + return ( + super().__eq__(other) + and self.name == other.name + and set(self.supported_systems) == set(other.supported_systems) + ) + + def can_run_on(self, system_name: str) -> bool: + """Check if the tool can run on the system passed as argument.""" + return system_name in self.supported_systems + + def get_details(self) -> Dict[str, Any]: + """Return dictionary with all relevant information of the Tool instance.""" + output = { + "type": "tool", + "name": self.name, + "description": self.description, + "supported_systems": self.supported_systems, + "commands": self._get_command_details(), + } + + return output + + +def load_tools(config: ExtendedToolConfig) -> List[Tool]: + """Load tool. + + Tool configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + Tool instance with appropriate configuration. + """ + configs = load_application_or_tool_configs( + config, ToolConfig, is_system_required=False + ) + tools = [Tool(cfg) for cfg in configs] + return tools diff --git a/src/aiet/cli/__init__.py b/src/aiet/cli/__init__.py new file mode 100644 index 0000000..bcd17c3 --- /dev/null +++ b/src/aiet/cli/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to mange the CLI interface.""" +import click + +from aiet import __version__ +from aiet.cli.application import application_cmd +from aiet.cli.completion import completion_cmd +from aiet.cli.system import system_cmd +from aiet.cli.tool import tool_cmd +from aiet.utils.helpers import set_verbosity + + +@click.group() +@click.version_option(__version__) +@click.option( + "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False +) +@click.pass_context +def cli(ctx: click.Context) -> None: # pylint: disable=unused-argument + """AIET: AI Evaluation Toolkit.""" + # Unused arguments must be present here in definition to pass click context. + + +cli.add_command(application_cmd) +cli.add_command(system_cmd) +cli.add_command(tool_cmd) +cli.add_command(completion_cmd) diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py new file mode 100644 index 0000000..59b652d --- /dev/null +++ b/src/aiet/cli/application.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause +"""Module to manage the CLI interface of applications.""" +import json +import logging +import re +from pathlib import Path +from typing import Any +from typing import IO +from typing import List +from typing import Optional +from typing import Tuple + +import click +import cloup + +from aiet.backend.application import get_application +from aiet.backend.application import get_available_application_directory_names +from aiet.backend.application import get_unique_application_names +from aiet.backend.application import install_application +from aiet.backend.application import remove_application +from aiet.backend.common import DataPaths +from aiet.backend.execution import execute_application_command +from aiet.backend.execution import run_application +from aiet.backend.system import get_available_systems +from aiet.cli.common import get_format +from aiet.cli.common import middleware_exception_handler +from aiet.cli.common import middleware_signal_handler +from aiet.cli.common import print_command_details +from aiet.cli.common import set_format + + +@click.group(name="application") +@click.option( + "-f", + "--format", + "format_", + type=click.Choice(["cli", "json"]), + default="cli", + show_default=True, +) +@click.pass_context +def application_cmd(ctx: click.Context, format_: str) -> None: + """Sub command to manage applications.""" + set_format(ctx, format_) + + +@application_cmd.command(name="list") +@click.pass_context +@click.option( + "-s", + "--system", + "system_name", + type=click.Choice([s.name for s in get_available_systems()]), + required=False, +) +def list_cmd(ctx: click.Context, system_name: str) -> None: + """List all available applications.""" + unique_application_names = get_unique_application_names(system_name) + unique_application_names.sort() + if get_format(ctx) == "json": + data = {"type": "application", "available": unique_application_names} + print(json.dumps(data)) + else: + print("Available applications:\n") + print(*unique_application_names, sep="\n") + + +@application_cmd.command(name="details") +@click.option( + "-n", + "--name", + "application_name", + type=click.Choice(get_unique_application_names()), + required=True, +) +@click.option( + "-s", + "--system", + "system_name", + type=click.Choice([s.name for s in get_available_systems()]), + required=False, +) +@click.pass_context +def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None: + """Details of a specific application.""" + applications = get_application(application_name, system_name) + if not applications: + raise click.UsageError( + "Application '{}' doesn't support the system '{}'".format( + application_name, system_name + ) + ) + + if get_format(ctx) == "json": + applications_details = [s.get_details() for s in applications] + print(json.dumps(applications_details)) + else: + for application in applications: + application_details = application.get_details() + application_details_template = ( + 'Application "{name}" details\nDescription: {description}' + ) + + print( + application_details_template.format( + name=application_details["name"], + description=application_details["description"], + ) + ) + + print( + "\nSupported systems: {}".format( + ", ".join(application_details["supported_systems"]) + ) + ) + + command_details = application_details["commands"] + + for command, details in command_details.items(): + print("\n{} commands:".format(command)) + print_command_details(details) + + +# pylint: disable=too-many-arguments +@application_cmd.command(name="execute") +@click.option( + "-n", + "--name", + "application_name", + type=click.Choice(get_unique_application_names()), + required=True, +) +@click.option( + "-s", + "--system", + "system_name", + type=click.Choice([s.name for s in get_available_systems()]), + required=True, +) +@click.option( + "-c", + "--command", + "command_name", + type=click.Choice(["build", "run"]), + required=True, +) +@click.option("-p", "--param", "application_params", multiple=True) +@click.option("--system-param", "system_params", multiple=True) +@click.option("-d", "--deploy", "deploy_params", multiple=True) +@middleware_signal_handler +@middleware_exception_handler +def execute_cmd( + application_name: str, + system_name: str, + command_name: str, + application_params: List[str], + system_params: List[str], + deploy_params: List[str], +) -> None: + """Execute application commands. DEPRECATED! Use 'aiet application run' instead.""" + logging.warning( + "Please use 'aiet application run' instead. Use of 'aiet application " + "execute' is deprecated and might be removed in a future release." + ) + + custom_deploy_data = get_custom_deploy_data(command_name, deploy_params) + + execute_application_command( + command_name, + application_name, + application_params, + system_name, + system_params, + custom_deploy_data, + ) + + +@cloup.command(name="run") +@cloup.option( + "-n", + "--name", + "application_name", + type=click.Choice(get_unique_application_names()), +) +@cloup.option( + "-s", + "--system", + "system_name", + type=click.Choice([s.name for s in get_available_systems()]), +) +@cloup.option("-p", "--param", "application_params", multiple=True) +@cloup.option("--system-param", "system_params", multiple=True) +@cloup.option("-d", "--deploy", "deploy_params", multiple=True) +@click.option( + "-r", + "--report", + "report_file", + type=Path, + help="Create a report file in JSON format containing metrics parsed from " + "the simulation output as specified in the aiet-config.json.", +) +@cloup.option( + "--config", + "config_file", + type=click.File("r"), + help="Read options from a config file rather than from the command line. " + "The config file is a json file.", +) +@cloup.constraint( + cloup.constraints.If( + cloup.constraints.conditions.Not( + cloup.constraints.conditions.IsSet("config_file") + ), + then=cloup.constraints.require_all, + ), + ["system_name", "application_name"], +) +@cloup.constraint( + cloup.constraints.If("config_file", then=cloup.constraints.accept_none), + [ + "system_name", + "application_name", + "application_params", + "system_params", + "deploy_params", + ], +) +@middleware_signal_handler +@middleware_exception_handler +def run_cmd( + application_name: str, + system_name: str, + application_params: List[str], + system_params: List[str], + deploy_params: List[str], + report_file: Optional[Path], + config_file: Optional[IO[str]], +) -> None: + """Execute application commands.""" + if config_file: + payload_data = json.load(config_file) + ( + system_name, + application_name, + application_params, + system_params, + deploy_params, + report_file, + ) = parse_payload_run_config(payload_data) + + custom_deploy_data = get_custom_deploy_data("run", deploy_params) + + run_application( + application_name, + application_params, + system_name, + system_params, + custom_deploy_data, + report_file, + ) + + +application_cmd.add_command(run_cmd) + + +def parse_payload_run_config( + payload_data: dict, +) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]: + """Parse the payload into a tuple.""" + system_id = payload_data.get("id") + arguments: Optional[Any] = payload_data.get("arguments") + + if not isinstance(system_id, str): + raise click.ClickException("invalid payload json: no system 'id'") + if not isinstance(arguments, dict): + raise click.ClickException("invalid payload json: no arguments object") + + application_name = arguments.pop("application", None) + if not isinstance(application_name, str): + raise click.ClickException("invalid payload json: no application_id") + + report_path = arguments.pop("report_path", None) + + application_params = [] + system_params = [] + deploy_params = [] + + for (param_key, value) in arguments.items(): + (par, _) = re.subn("^application/", "", param_key) + (par, found_sys_param) = re.subn("^system/", "", par) + (par, found_deploy_param) = re.subn("^deploy/", "", par) + + param_expr = par + "=" + value + if found_sys_param: + system_params.append(param_expr) + elif found_deploy_param: + deploy_params.append(par) + else: + application_params.append(param_expr) + + return ( + system_id, + application_name, + application_params, + system_params, + deploy_params, + report_path, + ) + + +def get_custom_deploy_data( + command_name: str, deploy_params: List[str] +) -> List[DataPaths]: + """Get custom deploy data information.""" + custom_deploy_data: List[DataPaths] = [] + if not deploy_params: + return custom_deploy_data + + for param in deploy_params: + parts = param.split(":") + if not len(parts) == 2 or any(not part.strip() for part in parts): + raise click.ClickException( + "Invalid deploy parameter '{}' for command {}".format( + param, command_name + ) + ) + data_path = DataPaths(Path(parts[0]), parts[1]) + if not data_path.src.exists(): + raise click.ClickException("Path {} does not exist".format(data_path.src)) + custom_deploy_data.append(data_path) + + return custom_deploy_data + + +@application_cmd.command(name="install") +@click.option( + "-s", + "--source", + "source", + required=True, + help="Path to the directory or archive with application definition", +) +def install_cmd(source: str) -> None: + """Install new application.""" + source_path = Path(source) + install_application(source_path) + + +@application_cmd.command(name="remove") +@click.option( + "-d", + "--directory_name", + "directory_name", + type=click.Choice(get_available_application_directory_names()), + required=True, + help="Name of the directory with application", +) +def remove_cmd(directory_name: str) -> None: + """Remove application.""" + remove_application(directory_name) diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py new file mode 100644 index 0000000..1d157b6 --- /dev/null +++ b/src/aiet/cli/common.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common functions for cli module.""" +import enum +import logging +from functools import wraps +from signal import SIG_IGN +from signal import SIGINT +from signal import signal as signal_handler +from signal import SIGTERM +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict + +from click import ClickException +from click import Context +from click import UsageError + +from aiet.backend.common import ConfigurationException +from aiet.backend.execution import AnotherInstanceIsRunningException +from aiet.backend.execution import ConnectionException +from aiet.backend.protocol import SSHConnectionException +from aiet.utils.proc import CommandFailedException + + +class MiddlewareExitCode(enum.IntEnum): + """Middleware exit codes.""" + + SUCCESS = 0 + # exit codes 1 and 2 are used by click + SHUTDOWN_REQUESTED = 3 + BACKEND_ERROR = 4 + CONCURRENT_ERROR = 5 + CONNECTION_ERROR = 6 + CONFIGURATION_ERROR = 7 + MODEL_OPTIMISED_ERROR = 8 + INVALID_TFLITE_FILE_ERROR = 9 + + +class CustomClickException(ClickException): + """Custom click exception.""" + + def show(self, file: Any = None) -> None: + """Override show method.""" + super().show(file) + + logging.debug("Execution failed with following exception: ", exc_info=self) + + +class MiddlewareShutdownException(CustomClickException): + """Exception indicates that user requested middleware shutdown.""" + + exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED) + + +class BackendException(CustomClickException): + """Exception indicates that command failed.""" + + exit_code = int(MiddlewareExitCode.BACKEND_ERROR) + + +class ConcurrentErrorException(CustomClickException): + """Exception indicates concurrent execution error.""" + + exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR) + + +class BackendConnectionException(CustomClickException): + """Exception indicates that connection could not be established.""" + + exit_code = int(MiddlewareExitCode.CONNECTION_ERROR) + + +class BackendConfigurationException(CustomClickException): + """Exception indicates some configuration issue.""" + + exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR) + + +class ModelOptimisedException(CustomClickException): + """Exception indicates input file has previously been Vela optimised.""" + + exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR) + + +class InvalidTFLiteFileError(CustomClickException): + """Exception indicates input TFLite file is misformatted.""" + + exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR) + + +def print_command_details(command: Dict) -> None: + """Print command details including parameters.""" + command_strings = command["command_strings"] + print("Commands: {}".format(command_strings)) + user_params = command["user_params"] + for i, param in enumerate(user_params, 1): + print("User parameter #{}".format(i)) + print("\tName: {}".format(param.get("name", "-"))) + print("\tDescription: {}".format(param["description"])) + print("\tPossible values: {}".format(param.get("values", "-"))) + print("\tDefault value: {}".format(param.get("default_value", "-"))) + print("\tAlias: {}".format(param.get("alias", "-"))) + + +def raise_exception_at_signal( + signum: int, frame: Any # pylint: disable=unused-argument +) -> None: + """Handle signals.""" + # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM + # signals will be ignored as we allow a graceful shutdown. + # Unused arguments must be present here in definition as used in signal handler + # callback + + signal_handler(SIGINT, SIG_IGN) + signal_handler(SIGTERM, SIG_IGN) + raise MiddlewareShutdownException("Middleware shutdown requested") + + +def middleware_exception_handler(func: Callable) -> Callable: + """Handle backend exceptions decorator.""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return func(*args, **kwargs) + except (MiddlewareShutdownException, UsageError, ClickException) as error: + # click should take care of these exceptions + raise error + except ValueError as error: + raise ClickException(str(error)) from error + except AnotherInstanceIsRunningException as error: + raise ConcurrentErrorException( + "Another instance of the system is running" + ) from error + except (SSHConnectionException, ConnectionException) as error: + raise BackendConnectionException(str(error)) from error + except ConfigurationException as error: + raise BackendConfigurationException(str(error)) from error + except (CommandFailedException, Exception) as error: + raise BackendException( + "Execution failed. Please check output for the details." + ) from error + + return wrapper + + +def middleware_signal_handler(func: Callable) -> Callable: + """Handle signals decorator.""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command) + # The handler ignores further signals and it raises an exception + signal_handler(SIGINT, raise_exception_at_signal) + signal_handler(SIGTERM, raise_exception_at_signal) + + return func(*args, **kwargs) + + return wrapper + + +def set_format(ctx: Context, format_: str) -> None: + """Save format in click context.""" + ctx_obj = ctx.ensure_object(dict) + ctx_obj["format"] = format_ + + +def get_format(ctx: Context) -> str: + """Get format from click context.""" + ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict)) + return ctx_obj["format"] diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py new file mode 100644 index 0000000..71f054f --- /dev/null +++ b/src/aiet/cli/completion.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Add auto completion to different shells with these helpers. + +See: https://click.palletsprojects.com/en/8.0.x/shell-completion/ +""" +import click + + +def _get_package_name() -> str: + return __name__.split(".", maxsplit=1)[0] + + +# aiet completion bash +@click.group(name="completion") +def completion_cmd() -> None: + """Enable auto completion for your shell.""" + + +@completion_cmd.command(name="bash") +def bash_cmd() -> None: + """ + Enable auto completion for bash. + + Use this command to activate completion in the current bash: + + eval "`aiet completion bash`" + + Use this command to add auto completion to bash globally, if you have aiet + installed globally (requires starting a new shell afterwards): + + aiet completion bash >> ~/.bashrc + """ + package_name = _get_package_name() + print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"') + + +@completion_cmd.command(name="zsh") +def zsh_cmd() -> None: + """ + Enable auto completion for zsh. + + Use this command to activate completion in the current zsh: + + eval "`aiet completion zsh`" + + Use this command to add auto completion to zsh globally, if you have aiet + installed globally (requires starting a new shell afterwards): + + aiet completion zsh >> ~/.zshrc + """ + package_name = _get_package_name() + print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"') + + +@completion_cmd.command(name="fish") +def fish_cmd() -> None: + """ + Enable auto completion for fish. + + Use this command to activate completion in the current fish: + + eval "`aiet completion fish`" + + Use this command to add auto completion to fish globally, if you have aiet + installed globally (requires starting a new shell afterwards): + + aiet completion fish >> ~/.config/fish/completions/aiet.fish + """ + package_name = _get_package_name() + print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"') diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py new file mode 100644 index 0000000..f1f7637 --- /dev/null +++ b/src/aiet/cli/system.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to manage the CLI interface of systems.""" +import json +from pathlib import Path +from typing import cast + +import click + +from aiet.backend.application import get_available_applications +from aiet.backend.system import get_available_systems +from aiet.backend.system import get_available_systems_directory_names +from aiet.backend.system import get_system +from aiet.backend.system import install_system +from aiet.backend.system import remove_system +from aiet.backend.system import System +from aiet.cli.common import get_format +from aiet.cli.common import print_command_details +from aiet.cli.common import set_format + + +@click.group(name="system") +@click.option( + "-f", + "--format", + "format_", + type=click.Choice(["cli", "json"]), + default="cli", + show_default=True, +) +@click.pass_context +def system_cmd(ctx: click.Context, format_: str) -> None: + """Sub command to manage systems.""" + set_format(ctx, format_) + + +@system_cmd.command(name="list") +@click.pass_context +def list_cmd(ctx: click.Context) -> None: + """List all available systems.""" + available_systems = get_available_systems() + system_names = [system.name for system in available_systems] + if get_format(ctx) == "json": + data = {"type": "system", "available": system_names} + print(json.dumps(data)) + else: + print("Available systems:\n") + print(*system_names, sep="\n") + + +@system_cmd.command(name="details") +@click.option( + "-n", + "--name", + "system_name", + type=click.Choice([s.name for s in get_available_systems()]), + required=True, +) +@click.pass_context +def details_cmd(ctx: click.Context, system_name: str) -> None: + """Details of a specific system.""" + system = cast(System, get_system(system_name)) + applications = [ + s.name for s in get_available_applications() if s.can_run_on(system.name) + ] + system_details = system.get_details() + if get_format(ctx) == "json": + system_details["available_application"] = applications + print(json.dumps(system_details)) + else: + system_details_template = ( + 'System "{name}" details\n' + "Description: {description}\n" + "Data Transfer Protocol: {protocol}\n" + "Available Applications: {available_application}" + ) + print( + system_details_template.format( + name=system_details["name"], + description=system_details["description"], + protocol=system_details["data_transfer_protocol"], + available_application=", ".join(applications), + ) + ) + + if system_details["annotations"]: + print("Annotations:") + for ann_name, ann_value in system_details["annotations"].items(): + print("\t{}: {}".format(ann_name, ann_value)) + + command_details = system_details["commands"] + for command, details in command_details.items(): + print("\n{} commands:".format(command)) + print_command_details(details) + + +@system_cmd.command(name="install") +@click.option( + "-s", + "--source", + "source", + required=True, + help="Path to the directory or archive with system definition", +) +def install_cmd(source: str) -> None: + """Install new system.""" + source_path = Path(source) + install_system(source_path) + + +@system_cmd.command(name="remove") +@click.option( + "-d", + "--directory_name", + "directory_name", + type=click.Choice(get_available_systems_directory_names()), + required=True, + help="Name of the directory with system", +) +def remove_cmd(directory_name: str) -> None: + """Remove system by given name.""" + remove_system(directory_name) diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py new file mode 100644 index 0000000..2c80821 --- /dev/null +++ b/src/aiet/cli/tool.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to manage the CLI interface of tools.""" +import json +from typing import Any +from typing import List +from typing import Optional + +import click + +from aiet.backend.execution import execute_tool_command +from aiet.backend.tool import get_tool +from aiet.backend.tool import get_unique_tool_names +from aiet.cli.common import get_format +from aiet.cli.common import middleware_exception_handler +from aiet.cli.common import middleware_signal_handler +from aiet.cli.common import print_command_details +from aiet.cli.common import set_format + + +@click.group(name="tool") +@click.option( + "-f", + "--format", + "format_", + type=click.Choice(["cli", "json"]), + default="cli", + show_default=True, +) +@click.pass_context +def tool_cmd(ctx: click.Context, format_: str) -> None: + """Sub command to manage tools.""" + set_format(ctx, format_) + + +@tool_cmd.command(name="list") +@click.pass_context +def list_cmd(ctx: click.Context) -> None: + """List all available tools.""" + # raise NotImplementedError("TODO") + tool_names = get_unique_tool_names() + tool_names.sort() + if get_format(ctx) == "json": + data = {"type": "tool", "available": tool_names} + print(json.dumps(data)) + else: + print("Available tools:\n") + print(*tool_names, sep="\n") + + +def validate_system( + ctx: click.Context, + _: click.Parameter, # param is not used + value: Any, +) -> Any: + """Validate provided system name depending on the the tool name.""" + tool_name = ctx.params["tool_name"] + tools = get_tool(tool_name, value) + if not tools: + supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)] + raise click.BadParameter( + message="'{}' is not one of {}.".format( + value, + ", ".join("'{}'".format(system) for system in supported_systems), + ), + ctx=ctx, + ) + return value + + +@tool_cmd.command(name="details") +@click.option( + "-n", + "--name", + "tool_name", + type=click.Choice(get_unique_tool_names()), + required=True, +) +@click.option( + "-s", + "--system", + "system_name", + callback=validate_system, + required=False, +) +@click.pass_context +@middleware_signal_handler +@middleware_exception_handler +def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None: + """Details of a specific tool.""" + tools = get_tool(tool_name, system_name) + if get_format(ctx) == "json": + tools_details = [s.get_details() for s in tools] + print(json.dumps(tools_details)) + else: + for tool in tools: + tool_details = tool.get_details() + tool_details_template = 'Tool "{name}" details\nDescription: {description}' + + print( + tool_details_template.format( + name=tool_details["name"], + description=tool_details["description"], + ) + ) + + print( + "\nSupported systems: {}".format( + ", ".join(tool_details["supported_systems"]) + ) + ) + + command_details = tool_details["commands"] + + for command, details in command_details.items(): + print("\n{} commands:".format(command)) + print_command_details(details) + + +# pylint: disable=too-many-arguments +@tool_cmd.command(name="execute") +@click.option( + "-n", + "--name", + "tool_name", + type=click.Choice(get_unique_tool_names()), + required=True, +) +@click.option("-p", "--param", "tool_params", multiple=True) +@click.option( + "-s", + "--system", + "system_name", + callback=validate_system, + required=False, +) +@middleware_signal_handler +@middleware_exception_handler +def execute_cmd( + tool_name: str, tool_params: List[str], system_name: Optional[str] +) -> None: + """Execute tool commands.""" + execute_tool_command(tool_name, tool_params, system_name) diff --git a/src/aiet/main.py b/src/aiet/main.py new file mode 100644 index 0000000..6898ad9 --- /dev/null +++ b/src/aiet/main.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Entry point module of AIET.""" +from aiet.cli import cli + + +def main() -> None: + """Entry point of aiet application.""" + cli() # pylint: disable=no-value-for-parameter + + +if __name__ == "__main__": + main() diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/aiet/resources/applications/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/aiet/resources/systems/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json new file mode 100644 index 0000000..c12f291 --- /dev/null +++ b/src/aiet/resources/tools/vela/aiet-config.json @@ -0,0 +1,73 @@ +[ + { + "name": "vela", + "description": "Neural network model compiler for Arm Ethos-U NPUs", + "supported_systems": [ + { + "name": "Corstone-300: Cortex-M55+Ethos-U55" + }, + { + "name": "Corstone-310: Cortex-M85+Ethos-U55" + }, + { + "name": "Corstone-300: Cortex-M55+Ethos-U65", + "variables": { + "accelerator_config_prefix": "ethos-u65", + "system_config": "Ethos_U65_High_End", + "shared_sram": "U65_Shared_Sram" + }, + "user_params": { + "run": [ + { + "description": "MACs per cycle", + "values": [ + "256", + "512" + ], + "default_value": "512", + "alias": "mac" + } + ] + } + } + ], + "variables": { + "accelerator_config_prefix": "ethos-u55", + "system_config": "Ethos_U55_High_End_Embedded", + "shared_sram": "U55_Shared_Sram" + }, + "commands": { + "run": [ + "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance" + ] + }, + "user_params": { + "run": [ + { + "description": "MACs per cycle", + "values": [ + "32", + "64", + "128", + "256" + ], + "default_value": "128", + "alias": "mac" + }, + { + "name": "--input-model", + "description": "Path to the TFLite model", + "values": [], + "alias": "input" + }, + { + "name": "--output-model", + "description": "Path to the output model file of the vela-optimisation step. The vela output is saved in the parent directory.", + "values": [], + "default_value": "output_model.tflite", + "alias": "output" + } + ] + } + } +] diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/aiet/resources/tools/vela/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py new file mode 100644 index 0000000..7c700b1 --- /dev/null +++ b/src/aiet/resources/tools/vela/check_model.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Check if a TFLite model file is Vela-optimised.""" +import struct +from pathlib import Path + +from ethosu.vela.tflite.Model import Model + +from aiet.cli.common import InvalidTFLiteFileError +from aiet.cli.common import ModelOptimisedException +from aiet.utils.fs import read_file_as_bytearray + + +def get_model_from_file(input_model_file: Path) -> Model: + """Generate Model instance from TFLite file using flatc generated code.""" + buffer = read_file_as_bytearray(input_model_file) + try: + model = Model.GetRootAsModel(buffer, 0) + except (TypeError, RuntimeError, struct.error) as tflite_error: + raise InvalidTFLiteFileError( + f"Error reading in model from {input_model_file}." + ) from tflite_error + return model + + +def is_vela_optimised(tflite_model: Model) -> bool: + """Return True if 'ethos-u' custom operator found in the Model.""" + operators = get_operators_from_model(tflite_model) + + custom_codes = get_custom_codes_from_operators(operators) + + return check_custom_codes_for_ethosu(custom_codes) + + +def get_operators_from_model(tflite_model: Model) -> list: + """Return list of the unique operator codes used in the Model.""" + return [ + tflite_model.OperatorCodes(index) + for index in range(tflite_model.OperatorCodesLength()) + ] + + +def get_custom_codes_from_operators(operators: list) -> list: + """Return list of each operator's CustomCode() strings, if they exist.""" + return [ + operator.CustomCode() + for operator in operators + if operator.CustomCode() is not None + ] + + +def check_custom_codes_for_ethosu(custom_codes: list) -> bool: + """Check for existence of ethos-u string in the custom codes.""" + return any( + custom_code_name.decode("utf-8") == "ethos-u" + for custom_code_name in custom_codes + ) + + +def check_model(tflite_file_name: str) -> None: + """Raise an exception if model in given file is Vela optimised.""" + tflite_path = Path(tflite_file_name) + + tflite_model = get_model_from_file(tflite_path) + + if is_vela_optimised(tflite_model): + raise ModelOptimisedException( + f"TFLite model in {tflite_file_name} is already " + f"vela optimised ('ethos-u' custom op detected)." + ) + + print( + f"TFLite model in {tflite_file_name} is not vela optimised " + f"('ethos-u' custom op not detected)." + ) diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py new file mode 100644 index 0000000..2c1b0be --- /dev/null +++ b/src/aiet/resources/tools/vela/run_vela.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Wrapper to only run Vela when the input is not already optimised.""" +import shutil +import subprocess +from pathlib import Path +from typing import Tuple + +import click + +from aiet.cli.common import ModelOptimisedException +from aiet.resources.tools.vela.check_model import check_model + + +def vela_output_model_path(input_model: str, output_dir: str) -> Path: + """Construct the path to the Vela output file.""" + in_path = Path(input_model) + tflite_vela = Path(output_dir) / f"{in_path.stem}_vela{in_path.suffix}" + return tflite_vela + + +def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None: + """Execute vela as external call.""" + cmd = ["vela"] + list(vela_args) + cmd += ["--output-dir", str(output_dir)] # Re-add parsed out_dir to arguments + cmd += [input_model] + subprocess.run(cmd, check=True) + + +@click.command(context_settings=dict(ignore_unknown_options=True)) +@click.option( + "--input-model", + "-i", + type=click.Path(exists=True, file_okay=True, readable=True), + required=True, +) +@click.option("--output-model", "-o", type=click.Path(), required=True) +# Collect the remaining arguments to be directly forwarded to Vela +@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED) +def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None: + """Check input, run Vela (if needed) and copy optimised file to destination.""" + output_dir = Path(output_model).parent + try: + check_model(input_model) # raises an exception if already Vela-optimised + execute_vela(vela_args, output_dir, input_model) + print("Vela optimisation complete.") + src_model = vela_output_model_path(input_model, str(output_dir)) + except ModelOptimisedException as ex: + # Input already optimized: copy input file to destination path and return + print(f"Input already vela-optimised.\n{ex}") + src_model = Path(input_model) + except subprocess.CalledProcessError as ex: + print(ex) + raise SystemExit(ex.returncode) from ex + + try: + shutil.copyfile(src_model, output_model) + except (shutil.SameFileError, OSError) as ex: + print(ex) + raise SystemExit(ex.errno) from ex + + +def main() -> None: + """Entry point of check_model application.""" + run_vela() # pylint: disable=no-value-for-parameter diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini new file mode 100644 index 0000000..5996553 --- /dev/null +++ b/src/aiet/resources/tools/vela/vela.ini @@ -0,0 +1,53 @@ +; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates. +; SPDX-License-Identifier: Apache-2.0 + +; ----------------------------------------------------------------------------- +; Vela configuration file + +; ----------------------------------------------------------------------------- +; System Configuration + +; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s) +[System_Config.Ethos_U55_High_End_Embedded] +core_clock=500e6 +axi0_port=Sram +axi1_port=OffChipFlash +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +OffChipFlash_clock_scale=0.125 +OffChipFlash_burst_length=128 +OffChipFlash_read_latency=64 +OffChipFlash_write_latency=64 + +; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s) +[System_Config.Ethos_U65_High_End] +core_clock=1e9 +axi0_port=Sram +axi1_port=Dram +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +Dram_clock_scale=0.234375 +Dram_burst_length=128 +Dram_read_latency=500 +Dram_write_latency=250 + +; ----------------------------------------------------------------------------- +; Memory Mode + +; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software +; The non-SRAM memory is assumed to be read-only +[Memory_Mode.U55_Shared_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi0 +cache_mem_area=Axi0 +arena_cache_size=4194304 + +[Memory_Mode.U65_Shared_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi0 +cache_mem_area=Axi0 +arena_cache_size=2097152 diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py new file mode 100644 index 0000000..fc7ef7c --- /dev/null +++ b/src/aiet/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""This module contains all utils shared across aiet project.""" diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py new file mode 100644 index 0000000..ea99a69 --- /dev/null +++ b/src/aiet/utils/fs.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to host all file system related functions.""" +import importlib.resources as pkg_resources +import re +import shutil +from pathlib import Path +from typing import Any +from typing import Literal +from typing import Optional + +ResourceType = Literal["applications", "systems", "tools"] + + +def get_aiet_resources() -> Path: + """Get resources folder path.""" + with pkg_resources.path("aiet", "__init__.py") as init_path: + project_root = init_path.parent + return project_root / "resources" + + +def get_resources(name: ResourceType) -> Path: + """Return the absolute path of the specified resource. + + It uses importlib to return resources packaged with MANIFEST.in. + """ + if not name: + raise ResourceWarning("Resource name is not provided") + + resource_path = get_aiet_resources() / name + if resource_path.is_dir(): + return resource_path + + raise ResourceWarning("Resource '{}' not found.".format(name)) + + +def copy_directory_content(source: Path, destination: Path) -> None: + """Copy content of the source directory into destination directory.""" + for item in source.iterdir(): + src = source / item.name + dest = destination / item.name + + if src.is_dir(): + shutil.copytree(src, dest) + else: + shutil.copy2(src, dest) + + +def remove_resource(resource_directory: str, resource_type: ResourceType) -> None: + """Remove resource data.""" + resources = get_resources(resource_type) + + resource_location = resources / resource_directory + if not resource_location.exists(): + raise Exception("Resource {} does not exist".format(resource_directory)) + + if not resource_location.is_dir(): + raise Exception("Wrong resource {}".format(resource_directory)) + + shutil.rmtree(resource_location) + + +def remove_directory(directory_path: Optional[Path]) -> None: + """Remove directory.""" + if not directory_path or not directory_path.is_dir(): + raise Exception("No directory path provided") + + shutil.rmtree(directory_path) + + +def recreate_directory(directory_path: Optional[Path]) -> None: + """Recreate directory.""" + if not directory_path: + raise Exception("No directory path provided") + + if directory_path.exists() and not directory_path.is_dir(): + raise Exception( + "Path {} does exist and it is not a directory".format(str(directory_path)) + ) + + if directory_path.is_dir(): + remove_directory(directory_path) + + directory_path.mkdir() + + +def read_file(file_path: Path, mode: Optional[str] = None) -> Any: + """Read file as string or bytearray.""" + if file_path.is_file(): + if mode is not None: + # Ignore pylint warning because mode can be 'binary' as well which + # is not compatible with specifying encodings. + with open(file_path, mode) as file: # pylint: disable=unspecified-encoding + return file.read() + else: + with open(file_path, encoding="utf-8") as file: + return file.read() + + if mode == "rb": + return b"" + return "" + + +def read_file_as_string(file_path: Path) -> str: + """Read file as string.""" + return str(read_file(file_path)) + + +def read_file_as_bytearray(file_path: Path) -> bytearray: + """Read a file as bytearray.""" + return bytearray(read_file(file_path, mode="rb")) + + +def valid_for_filename(value: str, replacement: str = "") -> str: + """Replace non alpha numeric characters.""" + return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII) diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py new file mode 100644 index 0000000..6d3cd22 --- /dev/null +++ b/src/aiet/utils/helpers.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Helpers functions.""" +import logging +from typing import Any + + +def set_verbosity( + ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument +) -> None: + """Set the logging level according to the verbosity.""" + # Unused arguments must be present here in definition as these are required in + # function definition when set as a callback + if verbosity == 1: + logging.getLogger().setLevel(logging.INFO) + elif verbosity > 1: + logging.getLogger().setLevel(logging.DEBUG) diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py new file mode 100644 index 0000000..b6f4357 --- /dev/null +++ b/src/aiet/utils/proc.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Processes module. + +This module contains all classes and functions for dealing with Linux +processes. +""" +import csv +import datetime +import logging +import shlex +import signal +import time +from pathlib import Path +from typing import Any +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple + +import psutil +from sh import Command +from sh import CommandNotFound +from sh import ErrorReturnCode +from sh import RunningCommand + +from aiet.utils.fs import valid_for_filename + + +class CommandFailedException(Exception): + """Exception for failed command execution.""" + + +class ShellCommand: + """Wrapper class for shell commands.""" + + def __init__(self, base_log_path: str = "/tmp") -> None: + """Initialise the class. + + base_log_path: it is the base directory where logs will be stored + """ + self.base_log_path = base_log_path + + def run( + self, + cmd: str, + *args: str, + _cwd: Optional[Path] = None, + _tee: bool = True, + _bg: bool = True, + _out: Any = None, + _err: Any = None, + _search_paths: Optional[List[Path]] = None + ) -> RunningCommand: + """Run the shell command with the given arguments. + + There are special arguments that modify the behaviour of the process. + _cwd: current working directory + _tee: it redirects the stdout both to console and file + _bg: if True, it runs the process in background and the command is not + blocking. + _out: use this object for stdout redirect, + _err: use this object for stderr redirect, + _search_paths: If presented used for searching executable + """ + try: + kwargs = {} + if _cwd: + kwargs["_cwd"] = str(_cwd) + command = Command(cmd, _search_paths).bake(args, **kwargs) + except CommandNotFound as error: + logging.error("Command '%s' not found", error.args[0]) + raise error + + out, err = _out, _err + if not _out and not _err: + out, err = [ + str(item) + for item in self.get_stdout_stderr_paths(self.base_log_path, cmd) + ] + + return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) + + @classmethod + def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + timestamp = datetime.datetime.now().timestamp() + base_path = Path(base_log_path) + base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) + stdout = base.with_suffix(".out") + stderr = base.with_suffix(".err") + try: + stdout.touch() + stderr.touch() + except FileNotFoundError as error: + logging.error("File not found: %s", error.filename) + raise error + return stdout, stderr + + +def parse_command(command: str, shell: str = "bash") -> List[str]: + """Parse command.""" + cmd, *args = shlex.split(command, posix=True) + + if is_shell_script(cmd): + args = [cmd] + args + cmd = shell + + return [cmd] + args + + +def get_stdout_stderr_paths( + command: str, base_log_path: str = "/tmp" +) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + cmd, *_ = parse_command(command) + + return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd) + + +def execute_command( # pylint: disable=invalid-name + command: str, + cwd: Path, + bg: bool = False, + shell: str = "bash", + out: Any = None, + err: Any = None, +) -> RunningCommand: + """Execute shell command.""" + cmd, *args = parse_command(command, shell) + + search_paths = None + if cmd != shell and (cwd / cmd).is_file(): + search_paths = [cwd] + + return ShellCommand().run( + cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err + ) + + +def is_shell_script(cmd: str) -> bool: + """Check if command is shell script.""" + return cmd.endswith(".sh") + + +def run_and_wait( + command: str, + cwd: Path, + terminate_on_error: bool = False, + out: Any = None, + err: Any = None, +) -> Tuple[int, bytearray, bytearray]: + """ + Run command and wait while it is executing. + + Returns a tuple: (exit_code, stdout, stderr) + """ + running_cmd: Optional[RunningCommand] = None + try: + running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) + return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr + except ErrorReturnCode as cmd_failed: + raise CommandFailedException() from cmd_failed + except Exception as error: + is_running = running_cmd is not None and running_cmd.is_alive() + if terminate_on_error and is_running: + print("Terminating ...") + terminate_command(running_cmd) + + raise error + + +def terminate_command( + running_cmd: RunningCommand, + wait: bool = True, + wait_period: float = 0.5, + number_of_attempts: int = 20, +) -> None: + """Terminate running command.""" + try: + running_cmd.process.signal_group(signal.SIGINT) + if wait: + for _ in range(number_of_attempts): + time.sleep(wait_period) + if not running_cmd.is_alive(): + return + print( + "Unable to terminate process {}. Sending SIGTERM...".format( + running_cmd.process.pid + ) + ) + running_cmd.process.signal_group(signal.SIGTERM) + except ProcessLookupError: + pass + + +def terminate_external_process( + process: psutil.Process, + wait_period: float = 0.5, + number_of_attempts: int = 20, + wait_for_termination: float = 5.0, +) -> None: + """Terminate external process.""" + try: + process.terminate() + for _ in range(number_of_attempts): + if not process.is_running(): + return + time.sleep(wait_period) + + if process.is_running(): + process.terminate() + time.sleep(wait_for_termination) + except psutil.Error: + print("Unable to terminate process") + + +class ProcessInfo(NamedTuple): + """Process information.""" + + name: str + executable: str + cwd: str + pid: int + + +def save_process_info(pid: int, pid_file: Path) -> None: + """Save process information to file.""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + family = [parent] + children + + with open(pid_file, "w", encoding="utf-8") as file: + csv_writer = csv.writer(file) + for member in family: + process_info = ProcessInfo( + name=member.name(), + executable=member.exe(), + cwd=member.cwd(), + pid=member.pid, + ) + csv_writer.writerow(process_info) + except psutil.NoSuchProcess: + # if process does not exist or finishes before + # function call then nothing could be saved + # just ignore this exception and exit + pass + + +def read_process_info(pid_file: Path) -> List[ProcessInfo]: + """Read information about previous system processes.""" + if not pid_file.is_file(): + return [] + + result = [] + with open(pid_file, encoding="utf-8") as file: + csv_reader = csv.reader(file) + for row in csv_reader: + name, executable, cwd, pid = row + result.append( + ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) + ) + + return result + + +def print_command_stdout(command: RunningCommand) -> None: + """Print the stdout of a command. + + The command has 2 states: running and done. + If the command is running, the output is taken by the running process. + If the command has ended its execution, the stdout is taken from stdout + property + """ + if command.is_alive(): + while True: + try: + print(command.next(), end="") + except StopIteration: + break + else: + print(command.stdout) diff --git a/src/mlia/__init__.py b/src/mlia/__init__.py new file mode 100644 index 0000000..ed9ae87 --- /dev/null +++ b/src/mlia/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Init of MLIA.""" +import logging +import os + +import pkg_resources + +# redirect warnings to logging +logging.captureWarnings(True) + + +# as TensorFlow tries to configure root logger +# it should be configured before importing TensorFlow +root_logger = logging.getLogger() +root_logger.addHandler(logging.NullHandler()) + + +# disable TensorFlow warning messages +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + +__version__ = pkg_resources.get_distribution("mlia").version diff --git a/src/mlia/api.py b/src/mlia/api.py new file mode 100644 index 0000000..53ea4c8 --- /dev/null +++ b/src/mlia/api.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for the API functions.""" +import logging +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from mlia.core._typing import PathOrFileLike +from mlia.core.advisor import InferenceAdvisor +from mlia.core.common import AdviceCategory +from mlia.core.context import ExecutionContext +from mlia.core.events import EventHandler +from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor +from mlia.devices.ethosu.handlers import EthosUEventHandler + + +logger = logging.getLogger(__name__) + + +_DEFAULT_OPTIMIZATION_TARGETS = [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + }, + { + "optimization_type": "clustering", + "optimization_target": 32, + "layers_to_optimize": None, + }, +] + + +def get_advice( + target_profile: str, + model: Union[Path, str], + category: Literal["all", "operators", "performance", "optimization"] = "all", + optimization_targets: Optional[List[Dict[str, Any]]] = None, + working_dir: Union[str, Path] = "mlia_output", + output: Optional[PathOrFileLike] = None, + context: Optional[ExecutionContext] = None, + backends: Optional[List[str]] = None, +) -> None: + """Get the advice. + + This function represents an entry point to the library API. + + Based on provided parameters it will collect and analyze the data + and produce the advice. + + :param target_profile: target profile identifier + :param model: path to the NN model + :param category: category of the advice. MLIA supports four categories: + "all", "operators", "performance", "optimization". If not provided + category "all" is used by default. + :param optimization_targets: optional model optimization targets that + could be used for generating advice in categories + "all" and "optimization." + :param working_dir: path to the directory that will be used for storing + intermediate files during execution (e.g. converted models) + :param output: path to the report file. If provided MLIA will save + report in this location. Format of the report automatically + detected based on file extension. + :param context: optional parameter which represents execution context, + could be used for advanced use cases + :param backends: A list of backends that should be used for the given + target. Default settings will be used if None. + + + Examples: + NB: Before launching MLIA, the logging functionality should be configured! + + Getting the advice for the provided target profile and the model + + >>> get_advice("ethos-u55-256", "path/to/the/model") + + Getting the advice for the category "performance" and save result report in file + "report.json" + + >>> get_advice("ethos-u55-256", "path/to/the/model", "performance", + output="report.json") + + """ + advice_category = AdviceCategory.from_string(category) + config_parameters = _get_config_parameters( + model, target_profile, backends, optimization_targets + ) + event_handlers = _get_event_handlers(output) + + if context is not None: + if context.advice_category is None: + context.advice_category = advice_category + + if context.config_parameters is None: + context.config_parameters = config_parameters + + if context.event_handlers is None: + context.event_handlers = event_handlers + + if context is None: + context = ExecutionContext( + advice_category=advice_category, + working_dir=working_dir, + config_parameters=config_parameters, + event_handlers=event_handlers, + ) + + advisor = _get_advisor(target_profile) + advisor.run(context) + + +def _get_advisor(target: Optional[str]) -> InferenceAdvisor: + """Find appropriate advisor for the target.""" + if not target: + raise Exception("Target is not provided") + + return EthosUInferenceAdvisor() + + +def _get_config_parameters( + model: Union[Path, str], + target_profile: str, + backends: Optional[List[str]], + optimization_targets: Optional[List[Dict[str, Any]]], +) -> Dict[str, Any]: + """Get configuration parameters for the advisor.""" + advisor_parameters: Dict[str, Any] = { + "ethos_u_inference_advisor": { + "model": model, + "device": { + "target_profile": target_profile, + }, + }, + } + # Specifying backends is optional (default is used) + if backends is not None: + advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends + + if not optimization_targets: + optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS + + advisor_parameters.update( + { + "ethos_u_model_optimizations": { + "optimizations": [ + optimization_targets, + ], + }, + } + ) + + return advisor_parameters + + +def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]: + """Return list of the event handlers.""" + return [EthosUEventHandler(output)] diff --git a/src/mlia/cli/__init__.py b/src/mlia/cli/__init__.py new file mode 100644 index 0000000..f50778e --- /dev/null +++ b/src/mlia/cli/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI module.""" diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py new file mode 100644 index 0000000..45c7c32 --- /dev/null +++ b/src/mlia/cli/commands.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI commands module. + +This module contains functions which implement main app +functionality. + +Before running them from scripts 'logging' module should +be configured. Function 'setup_logging' from module +'mli.cli.logging' could be used for that, e.g. + +>>> from mlia.api import ExecutionContext +>>> from mlia.cli.logging import setup_logging +>>> setup_logging(verbose=True) +>>> import mlia.cli.commands as mlia +>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", + "path/to/model") +""" +import logging +from pathlib import Path +from typing import cast +from typing import List +from typing import Optional + +from mlia.api import ExecutionContext +from mlia.api import get_advice +from mlia.api import PathOrFileLike +from mlia.cli.config import get_installation_manager +from mlia.cli.options import parse_optimization_parameters +from mlia.devices.ethosu.operators import generate_supported_operators_report +from mlia.utils.console import create_section_header +from mlia.utils.types import only_one_selected + +logger = logging.getLogger(__name__) + +CONFIG = create_section_header("ML Inference Advisor configuration") + + +def all_tests( + ctx: ExecutionContext, + target_profile: str, + model: str, + optimization_type: str = "pruning,clustering", + optimization_target: str = "0.5,32", + output: Optional[PathOrFileLike] = None, + evaluate_on: Optional[List[str]] = None, +) -> None: + """Generate a full report on the input model. + + This command runs a series of tests in order to generate a + comprehensive report/advice: + + - converts the input Keras model into TFLite format + - checks the model for operator compatibility on the specified device + - applies optimizations to the model and estimates the resulting performance + on both the original and the optimized models + - generates a final report on the steps above + - provides advice on how to (possibly) improve the inference performance + + :param ctx: execution context + :param target_profile: target profile identifier. Will load appropriate parameters + from the profile.json file based on this argument. + :param model: path to the Keras model + :param optimization_type: list of the optimization techniques separated + by comma, e.g. 'pruning,clustering' + :param optimization_target: list of the corresponding targets for + the provided optimization techniques, e.g. '0.5,32' + :param output: path to the file where the report will be saved + :param evaluate_on: list of the backends to use for evaluation + + Example: + Run command for the target profile ethos-u55-256 with two model optimizations + and save report in json format locally in the file report.json + + >>> from mlia.api import ExecutionContext + >>> from mlia.cli.logging import setup_logging + >>> setup_logging() + >>> from mlia.cli.commands import all_tests + >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", + "model.h5", "pruning,clustering", "0.5,32", + output="report.json") + """ + opt_params = parse_optimization_parameters( + optimization_type, + optimization_target, + ) + + get_advice( + target_profile, + model, + "all", + optimization_targets=opt_params, + output=output, + context=ctx, + backends=evaluate_on, + ) + + +def operators( + ctx: ExecutionContext, + target_profile: str, + model: Optional[str] = None, + output: Optional[PathOrFileLike] = None, + supported_ops_report: bool = False, +) -> None: + """Print the model's operator list. + + This command checks the operator compatibility of the input model with + the specific target profile. Generates a report of the operator placement + (NPU or CPU fallback) and advice on how to improve it (if necessary). + + :param ctx: execution context + :param target_profile: target profile identifier. Will load appropriate parameters + from the profile.json file based on this argument. + :param model: path to the model, which can be TFLite or Keras + :param output: path to the file where the report will be saved + :param supported_ops_report: if True then generates supported operators + report in current directory and exits + + Example: + Run command for the target profile ethos-u55-256 and the provided + TFLite model and print report on the standard output + + >>> from mlia.api import ExecutionContext + >>> from mlia.cli.logging import setup_logging + >>> setup_logging() + >>> from mlia.cli.commands import operators + >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", + "model.tflite") + """ + if supported_ops_report: + generate_supported_operators_report() + logger.info("Report saved into SUPPORTED_OPS.md") + return + + if not model: + raise Exception("Model is not provided") + + get_advice( + target_profile, + model, + "operators", + output=output, + context=ctx, + ) + + +def performance( + ctx: ExecutionContext, + target_profile: str, + model: str, + output: Optional[PathOrFileLike] = None, + evaluate_on: Optional[List[str]] = None, +) -> None: + """Print the model's performance stats. + + This command estimates the inference performance of the input model + on the specified target profile, and generates a report with advice on how + to improve it. + + :param ctx: execution context + :param target_profile: target profile identifier. Will load appropriate parameters + from the profile.json file based on this argument. + :param model: path to the model, which can be TFLite or Keras + :param output: path to the file where the report will be saved + :param evaluate_on: list of the backends to use for evaluation + + Example: + Run command for the target profile ethos-u55-256 and + the provided TFLite model and print report on the standard output + + >>> from mlia.api import ExecutionContext + >>> from mlia.cli.logging import setup_logging + >>> setup_logging() + >>> from mlia.cli.commands import performance + >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", + "model.tflite") + """ + get_advice( + target_profile, + model, + "performance", + output=output, + context=ctx, + backends=evaluate_on, + ) + + +def optimization( + ctx: ExecutionContext, + target_profile: str, + model: str, + optimization_type: str, + optimization_target: str, + layers_to_optimize: Optional[List[str]] = None, + output: Optional[PathOrFileLike] = None, + evaluate_on: Optional[List[str]] = None, +) -> None: + """Show the performance improvements (if any) after applying the optimizations. + + This command applies the selected optimization techniques (up to the + indicated targets) and generates a report with advice on how to improve + the inference performance (if possible). + + :param ctx: execution context + :param target: target profile identifier. Will load appropriate parameters + from the profile.json file based on this argument. + :param model: path to the TFLite model + :param optimization_type: list of the optimization techniques separated + by comma, e.g. 'pruning,clustering' + :param optimization_target: list of the corresponding targets for + the provided optimization techniques, e.g. '0.5,32' + :param layers_to_optimize: list of the layers of the model which should be + optimized, if None then all layers are used + :param output: path to the file where the report will be saved + :param evaluate_on: list of the backends to use for evaluation + + Example: + Run command for the target profile ethos-u55-256 and + the provided TFLite model and print report on the standard output + + >>> from mlia.cli.logging import setup_logging + >>> setup_logging() + >>> from mlia.cli.commands import optimization + >>> optimization(ExecutionContext(working_dir="mlia_output"), + target="ethos-u55-256", + "model.tflite", "pruning", "0.5") + """ + opt_params = parse_optimization_parameters( + optimization_type, + optimization_target, + layers_to_optimize=layers_to_optimize, + ) + + get_advice( + target_profile, + model, + "optimization", + optimization_targets=opt_params, + output=output, + context=ctx, + backends=evaluate_on, + ) + + +def backend( + backend_action: str, + path: Optional[Path] = None, + download: bool = False, + name: Optional[str] = None, + i_agree_to_the_contained_eula: bool = False, + noninteractive: bool = False, +) -> None: + """Backends configuration.""" + logger.info(CONFIG) + + manager = get_installation_manager(noninteractive) + + if backend_action == "status": + manager.show_env_details() + + if backend_action == "install": + install_from_path = path is not None + + if not only_one_selected(install_from_path, download): + raise Exception( + "Please select only one action: download or " + "provide path to the backend installation" + ) + + if install_from_path: + manager.install_from(cast(Path, path), name) + + if download: + eula_agreement = not i_agree_to_the_contained_eula + manager.download_and_install(name, eula_agreement) diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py new file mode 100644 index 0000000..54bd457 --- /dev/null +++ b/src/mlia/cli/common.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI common module.""" +import argparse +from dataclasses import dataclass +from typing import Callable +from typing import List + + +@dataclass +class CommandInfo: + """Command description.""" + + func: Callable + aliases: List[str] + opt_groups: List[Callable[[argparse.ArgumentParser], None]] + is_default: bool = False + + @property + def command_name(self) -> str: + """Return command name.""" + return self.func.__name__ + + @property + def command_name_and_aliases(self) -> List[str]: + """Return list of command name and aliases.""" + return [self.command_name, *self.aliases] + + @property + def command_help(self) -> str: + """Return help message for the command.""" + assert self.func.__doc__, "Command function does not have a docstring" + func_help = self.func.__doc__.splitlines()[0].rstrip(".") + + if self.is_default: + func_help = f"{func_help} [default]" + + return func_help diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py new file mode 100644 index 0000000..838b051 --- /dev/null +++ b/src/mlia/cli/config.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Environment configuration functions.""" +import logging +from functools import lru_cache +from typing import List + +import mlia.tools.aiet_wrapper as aiet +from mlia.tools.metadata.common import DefaultInstallationManager +from mlia.tools.metadata.common import InstallationManager +from mlia.tools.metadata.corstone import get_corstone_installations + +logger = logging.getLogger(__name__) + + +def get_installation_manager(noninteractive: bool = False) -> InstallationManager: + """Return installation manager.""" + backends = get_corstone_installations() + + return DefaultInstallationManager(backends, noninteractive=noninteractive) + + +@lru_cache +def get_available_backends() -> List[str]: + """Return list of the available backends.""" + available_backends = ["Vela"] + + # Add backends using AIET + manager = get_installation_manager() + available_backends.extend( + ( + backend + for backend in aiet.supported_backends() + if manager.backend_installed(backend) + ) + ) + + return available_backends + + +# List of mutually exclusive Corstone backends ordered by priority +_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300") + + +def get_default_backends() -> List[str]: + """Get default backends for evaluation.""" + backends = get_available_backends() + + # Filter backends to only include one Corstone backend + for corstone in _CORSTONE_EXCLUSIVE_PRIORITY: + if corstone in backends: + backends = [ + backend + for backend in backends + if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY + ] + break + + return backends + + +def is_corstone_backend(backend: str) -> bool: + """Check if the given backend is a Corstone backend.""" + return backend in _CORSTONE_EXCLUSIVE_PRIORITY diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py new file mode 100644 index 0000000..81d5a15 --- /dev/null +++ b/src/mlia/cli/helpers.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for various helper classes.""" +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from mlia.cli.options import get_target_profile_opts +from mlia.core.helpers import ActionResolver +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.utils import is_keras_model +from mlia.utils.types import is_list_of + + +class CLIActionResolver(ActionResolver): + """Helper class for generating cli commands.""" + + def __init__(self, args: Dict[str, Any]) -> None: + """Init action resolver.""" + self.args = args + + @staticmethod + def _general_optimization_command(model_path: Optional[str]) -> List[str]: + """Return general optimization command description.""" + keras_note = [] + if model_path is None or not is_keras_model(model_path): + model_path = "/path/to/keras_model" + keras_note = ["Note: you will need a Keras model for that."] + + return [ + *keras_note, + "For example: mlia optimization --optimization-type " + f"pruning,clustering --optimization-target 0.5,32 {model_path}", + "For more info: mlia optimization --help", + ] + + @staticmethod + def _specific_optimization_command( + model_path: str, + device_opts: str, + opt_settings: List[OptimizationSettings], + ) -> List[str]: + """Return specific optimization command description.""" + opt_types = ",".join(opt.optimization_type for opt in opt_settings) + opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings) + + return [ + "For more info: mlia optimization --help", + "Optimization command: " + f"mlia optimization --optimization-type {opt_types} " + f"--optimization-target {opt_targs}{device_opts} {model_path}", + ] + + def apply_optimizations(self, **kwargs: Any) -> List[str]: + """Return command details for applying optimizations.""" + model_path, device_opts = self._get_model_and_device_opts() + + if (opt_settings := kwargs.pop("opt_settings", None)) is None: + return self._general_optimization_command(model_path) + + if is_list_of(opt_settings, OptimizationSettings) and model_path: + return self._specific_optimization_command( + model_path, device_opts, opt_settings + ) + + return [] + + def supported_operators_info(self) -> List[str]: + """Return command details for generating supported ops report.""" + return [ + "For guidance on supported operators, run: mlia operators " + "--supported-ops-report", + ] + + def check_performance(self) -> List[str]: + """Return command details for checking performance.""" + model_path, device_opts = self._get_model_and_device_opts() + if not model_path: + return [] + + return [ + "Check the estimated performance by running the following command: ", + f"mlia performance{device_opts} {model_path}", + ] + + def check_operator_compatibility(self) -> List[str]: + """Return command details for op compatibility.""" + model_path, device_opts = self._get_model_and_device_opts() + if not model_path: + return [] + + return [ + "Try running the following command to verify that:", + f"mlia operators{device_opts} {model_path}", + ] + + def operator_compatibility_details(self) -> List[str]: + """Return command details for op compatibility.""" + return ["For more details, run: mlia operators --help"] + + def optimization_details(self) -> List[str]: + """Return command details for optimization.""" + return ["For more info, see: mlia optimization --help"] + + def _get_model_and_device_opts( + self, separate_device_opts: bool = True + ) -> Tuple[Optional[str], str]: + """Get model and device options.""" + device_opts = " ".join(get_target_profile_opts(self.args)) + if separate_device_opts and device_opts: + device_opts = f" {device_opts}" + + model_path = self.args.get("model") + return model_path, device_opts diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py new file mode 100644 index 0000000..c5fc7bd --- /dev/null +++ b/src/mlia/cli/logging.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI logging configuration.""" +import logging +import sys +from pathlib import Path +from typing import List +from typing import Optional +from typing import Union + +from mlia.utils.logging import attach_handlers +from mlia.utils.logging import create_log_handler +from mlia.utils.logging import LogFilter + + +_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s" +_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + +def setup_logging( + logs_dir: Optional[Union[str, Path]] = None, + verbose: bool = False, + log_filename: str = "mlia.log", +) -> None: + """Set up logging. + + MLIA uses module 'logging' when it needs to produce output. + + :param logs_dir: path to the directory where application will save logs with + debug information. If the path is not provided then no log files will + be created during execution + :param verbose: enable extended logging for the tools loggers + :param log_filename: name of the log file in the logs directory + """ + mlia_logger, *tools_loggers = [ + logging.getLogger(logger_name) + for logger_name in ["mlia", "tensorflow", "py.warnings"] + ] + + # enable debug output, actual message filtering depends on + # the provided parameters and being done on the handlers level + mlia_logger.setLevel(logging.DEBUG) + + mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose) + attach_handlers(mlia_handlers, [mlia_logger]) + + tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose) + attach_handlers(tools_handlers, tools_loggers) + + +def _get_mlia_handlers( + logs_dir: Optional[Union[str, Path]], + log_filename: str, + verbose: bool, +) -> List[logging.Handler]: + """Get handlers for the MLIA loggers.""" + result = [] + stdout_handler = create_log_handler( + stream=sys.stdout, + log_level=logging.INFO, + ) + result.append(stdout_handler) + + if verbose: + mlia_verbose_handler = create_log_handler( + stream=sys.stdout, + log_level=logging.DEBUG, + log_format=_CONSOLE_DEBUG_FORMAT, + log_filter=LogFilter.equals(logging.DEBUG), + ) + result.append(mlia_verbose_handler) + + if logs_dir: + mlia_file_handler = create_log_handler( + file_path=_get_log_file(logs_dir, log_filename), + log_level=logging.DEBUG, + log_format=_FILE_DEBUG_FORMAT, + log_filter=LogFilter.skip(logging.INFO), + delay=True, + ) + result.append(mlia_file_handler) + + return result + + +def _get_tools_handlers( + logs_dir: Optional[Union[str, Path]], + log_filename: str, + verbose: bool, +) -> List[logging.Handler]: + """Get handler for the tools loggers.""" + result = [] + if verbose: + verbose_stdout_handler = create_log_handler( + stream=sys.stdout, + log_level=logging.DEBUG, + log_format=_CONSOLE_DEBUG_FORMAT, + ) + result.append(verbose_stdout_handler) + + if logs_dir: + file_handler = create_log_handler( + file_path=_get_log_file(logs_dir, log_filename), + log_level=logging.DEBUG, + log_format=_FILE_DEBUG_FORMAT, + delay=True, + ) + result.append(file_handler) + + return result + + +def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path: + """Get the log file path.""" + logs_dir_path = Path(logs_dir) + logs_dir_path.mkdir(exist_ok=True) + return logs_dir_path / log_filename diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py new file mode 100644 index 0000000..33fcdeb --- /dev/null +++ b/src/mlia/cli/main.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI main entry point.""" +import argparse +import logging +import sys +from inspect import signature +from pathlib import Path +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from mlia import __version__ +from mlia.cli.commands import all_tests +from mlia.cli.commands import backend +from mlia.cli.commands import operators +from mlia.cli.commands import optimization +from mlia.cli.commands import performance +from mlia.cli.common import CommandInfo +from mlia.cli.helpers import CLIActionResolver +from mlia.cli.logging import setup_logging +from mlia.cli.options import add_backend_options +from mlia.cli.options import add_custom_supported_operators_options +from mlia.cli.options import add_debug_options +from mlia.cli.options import add_evaluation_options +from mlia.cli.options import add_keras_model_options +from mlia.cli.options import add_multi_optimization_options +from mlia.cli.options import add_optional_tflite_model_options +from mlia.cli.options import add_output_options +from mlia.cli.options import add_target_options +from mlia.cli.options import add_tflite_model_options +from mlia.core.context import ExecutionContext + + +logger = logging.getLogger(__name__) + +INFO_MESSAGE = f""" +ML Inference Advisor {__version__} + +Help the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU + +Supported targets: + + - Ethos-U55 + - Ethos-U65 + +""".strip() + + +def get_commands() -> List[CommandInfo]: + """Return commands configuration.""" + return [ + CommandInfo( + all_tests, + ["all"], + [ + add_target_options, + add_keras_model_options, + add_multi_optimization_options, + add_output_options, + add_debug_options, + add_evaluation_options, + ], + True, + ), + CommandInfo( + operators, + ["ops"], + [ + add_target_options, + add_optional_tflite_model_options, + add_output_options, + add_custom_supported_operators_options, + add_debug_options, + ], + ), + CommandInfo( + performance, + ["perf"], + [ + add_target_options, + add_tflite_model_options, + add_output_options, + add_debug_options, + add_evaluation_options, + ], + ), + CommandInfo( + optimization, + ["opt"], + [ + add_target_options, + add_keras_model_options, + add_multi_optimization_options, + add_output_options, + add_debug_options, + add_evaluation_options, + ], + ), + CommandInfo( + backend, + [], + [ + add_backend_options, + add_debug_options, + ], + ), + ] + + +def get_default_command() -> Optional[str]: + """Get name of the default command.""" + commands = get_commands() + + marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default] + assert len(marked_as_default) <= 1, "Only one command could be marked as default" + + return next(iter(marked_as_default), None) + + +def get_possible_command_names() -> List[str]: + """Get all possible command names including aliases.""" + return [ + name_or_alias + for cmd in get_commands() + for name_or_alias in cmd.command_name_and_aliases + ] + + +def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Init cli subcommands.""" + subparsers = parser.add_subparsers(title="Commands", dest="command") + subparsers.required = True + + for command in get_commands(): + command_parser = subparsers.add_parser( + command.command_name, + aliases=command.aliases, + help=command.command_help, + allow_abbrev=False, + ) + command_parser.set_defaults(func=command.func) + for opt_group in command.opt_groups: + opt_group(command_parser) + + return parser + + +def setup_context( + args: argparse.Namespace, context_var_name: str = "ctx" +) -> Tuple[ExecutionContext, Dict]: + """Set up context and resolve function parameters.""" + ctx = ExecutionContext( + working_dir=args.working_dir, + verbose="verbose" in args and args.verbose, + action_resolver=CLIActionResolver(vars(args)), + ) + + # these parameters should not be passed into command function + skipped_params = ["func", "command", "working_dir", "verbose"] + + # pass these parameters only if command expects them + expected_params = [context_var_name] + func_params = signature(args.func).parameters + + params = {context_var_name: ctx, **vars(args)} + + func_args = { + param_name: param_value + for param_name, param_value in params.items() + if param_name not in skipped_params + and (param_name not in expected_params or param_name in func_params) + } + + return (ctx, func_args) + + +def run_command(args: argparse.Namespace) -> int: + """Run command.""" + ctx, func_args = setup_context(args) + setup_logging(ctx.logs_path, ctx.verbose) + + logger.debug( + "*** This is the beginning of the command '%s' execution ***", args.command + ) + + try: + logger.info(INFO_MESSAGE) + + args.func(**func_args) + return 0 + except KeyboardInterrupt: + logger.error("Execution has been interrupted") + except Exception as err: # pylint: disable=broad-except + logger.error( + "\nExecution finished with error: %s", + err, + exc_info=err if ctx.verbose else None, + ) + + err_advice_message = ( + f"Please check the log files in the {ctx.logs_path} for more details" + ) + if not ctx.verbose: + err_advice_message += ", or enable verbose mode" + + logger.error(err_advice_message) + + return 1 + + +def init_common_parser() -> argparse.ArgumentParser: + """Init common parser.""" + parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + parser.add_argument( + "--working-dir", + default=f"{Path.cwd() / 'mlia_output'}", + help="Path to the directory where MLIA will store logs, " + "models, etc. (default: %(default)s)", + ) + + return parser + + +def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Init subcommand parser.""" + parser = argparse.ArgumentParser( + description=INFO_MESSAGE, + formatter_class=argparse.RawDescriptionHelpFormatter, + parents=[parent], + add_help=False, + allow_abbrev=False, + ) + parser.add_argument( + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=f"%(prog)s {__version__}", + help="Show program's version number and exit", + ) + + return parser + + +def add_default_command_if_needed(args: List[str]) -> None: + """Add default command to the list of the arguments if needed.""" + default_command = get_default_command() + + if default_command and len(args) > 0: + commands = get_possible_command_names() + help_or_version = ["-h", "--help", "-v", "--version"] + + command_is_missing = args[0] not in [*commands, *help_or_version] + if command_is_missing: + args.insert(0, default_command) + + +def main(argv: Optional[List[str]] = None) -> int: + """Entry point of the application.""" + common_parser = init_common_parser() + subcommand_parser = init_subcommand_parser(common_parser) + init_commands(subcommand_parser) + + common_args, subcommand_args = common_parser.parse_known_args(argv) + add_default_command_if_needed(subcommand_args) + + args = subcommand_parser.parse_args(subcommand_args, common_args) + return run_command(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py new file mode 100644 index 0000000..dc5cb73 --- /dev/null +++ b/src/mlia/cli/options.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for the CLI options.""" +import argparse +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional + +from mlia.cli.config import get_available_backends +from mlia.cli.config import get_default_backends +from mlia.cli.config import is_corstone_backend +from mlia.utils.filesystem import get_supported_profile_names +from mlia.utils.types import is_number + + +def add_target_options(parser: argparse.ArgumentParser) -> None: + """Add target specific options.""" + target_profiles = get_supported_profile_names() + + default_target_profile = None + default_help = "" + if target_profiles: + default_target_profile = target_profiles[0] + default_help = " (default: %(default)s)" + + target_group = parser.add_argument_group("target options") + target_group.add_argument( + "--target-profile", + choices=target_profiles, + default=default_target_profile, + help="Target profile that will set the target options " + "such as target, mac value, memory mode, etc. " + f"For the values associated with each target profile " + f" please refer to the documenation {default_help}.", + ) + + +def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None: + """Add optimization specific options.""" + multi_optimization_group = parser.add_argument_group("optimization options") + + multi_optimization_group.add_argument( + "--optimization-type", + default="pruning,clustering", + help="List of the optimization types separated by comma (default: %(default)s)", + ) + multi_optimization_group.add_argument( + "--optimization-target", + default="0.5,32", + help="""List of the optimization targets separated by comma, + (for pruning this is sparsity between (0,1), + for clustering this is the number of clusters (positive integer)) + (default: %(default)s)""", + ) + + +def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None: + """Add optional model specific options.""" + model_group = parser.add_argument_group("TFLite model options") + # make model parameter optional + model_group.add_argument("model", nargs="?", help="TFLite model (optional)") + + +def add_tflite_model_options(parser: argparse.ArgumentParser) -> None: + """Add model specific options.""" + model_group = parser.add_argument_group("TFLite model options") + model_group.add_argument("model", help="TFLite model") + + +def add_output_options(parser: argparse.ArgumentParser) -> None: + """Add output specific options.""" + valid_extensions = ["csv", "json"] + + def check_extension(filename: str) -> str: + """Check extension of the provided file.""" + suffix = Path(filename).suffix + if suffix.startswith("."): + suffix = suffix[1:] + + if suffix.lower() not in valid_extensions: + parser.error(f"Unsupported format '{suffix}'") + + return filename + + output_group = parser.add_argument_group("output options") + output_group.add_argument( + "--output", + type=check_extension, + help=( + "Name of the file where report will be saved. " + "Report format is automatically detected based on the file extension. " + f"Supported formats are: {', '.join(valid_extensions)}" + ), + ) + + +def add_debug_options(parser: argparse.ArgumentParser) -> None: + """Add debug options.""" + debug_group = parser.add_argument_group("debug options") + debug_group.add_argument( + "--verbose", default=False, action="store_true", help="Produce verbose output" + ) + + +def add_keras_model_options(parser: argparse.ArgumentParser) -> None: + """Add model specific options.""" + model_group = parser.add_argument_group("Keras model options") + model_group.add_argument("model", help="Keras model") + + +def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None: + """Add custom options for the command 'operators'.""" + parser.add_argument( + "--supported-ops-report", + action="store_true", + default=False, + help=( + "Generate the SUPPORTED_OPS.md file in the " + "current working directory and exit" + ), + ) + + +def add_backend_options(parser: argparse.ArgumentParser) -> None: + """Add options for the backends configuration.""" + + def valid_directory(param: str) -> Path: + """Check if passed string is a valid directory path.""" + if not (dir_path := Path(param)).is_dir(): + parser.error(f"Invalid directory path {param}") + + return dir_path + + subparsers = parser.add_subparsers(title="Backend actions", dest="backend_action") + subparsers.required = True + + install_subparser = subparsers.add_parser( + "install", help="Install backend", allow_abbrev=False + ) + install_type_group = install_subparser.add_mutually_exclusive_group() + install_type_group.required = True + install_type_group.add_argument( + "--path", type=valid_directory, help="Path to the installed backend" + ) + install_type_group.add_argument( + "--download", + default=False, + action="store_true", + help="Download and install backend", + ) + install_subparser.add_argument( + "--i-agree-to-the-contained-eula", + default=False, + action="store_true", + help=argparse.SUPPRESS, + ) + install_subparser.add_argument( + "--noninteractive", + default=False, + action="store_true", + help="Non interactive mode with automatic confirmation of every action", + ) + install_subparser.add_argument( + "name", + nargs="?", + help="Name of the backend to install", + ) + + subparsers.add_parser("status", help="Show backends status") + + +def add_evaluation_options(parser: argparse.ArgumentParser) -> None: + """Add evaluation options.""" + available_backends = get_available_backends() + default_backends = get_default_backends() + + def only_one_corstone_checker() -> Callable: + """ + Return a callable to check that only one Corstone backend is passed. + + Raises an exception when more than one Corstone backend is passed. + """ + num_corstones = 0 + + def check(backend: str) -> str: + """Count Corstone backends and raise an exception if more than one.""" + nonlocal num_corstones + if is_corstone_backend(backend): + num_corstones = num_corstones + 1 + if num_corstones > 1: + raise argparse.ArgumentTypeError( + "There must be only one Corstone backend in the argument list." + ) + return backend + + return check + + evaluation_group = parser.add_argument_group("evaluation options") + evaluation_group.add_argument( + "--evaluate-on", + help="Backends to use for evaluation (default: %(default)s)", + nargs="*", + choices=available_backends, + default=default_backends, + type=only_one_corstone_checker(), + ) + + +def parse_optimization_parameters( + optimization_type: str, + optimization_target: str, + sep: str = ",", + layers_to_optimize: Optional[List[str]] = None, +) -> List[Dict[str, Any]]: + """Parse provided optimization parameters.""" + if not optimization_type: + raise Exception("Optimization type is not provided") + + if not optimization_target: + raise Exception("Optimization target is not provided") + + opt_types = optimization_type.split(sep) + opt_targets = optimization_target.split(sep) + + if len(opt_types) != len(opt_targets): + raise Exception("Wrong number of optimization targets and types") + + non_numeric_targets = [ + opt_target for opt_target in opt_targets if not is_number(opt_target) + ] + if len(non_numeric_targets) > 0: + raise Exception("Non numeric value for the optimization target") + + optimizer_params = [ + { + "optimization_type": opt_type.strip(), + "optimization_target": float(opt_target), + "layers_to_optimize": layers_to_optimize, + } + for opt_type, opt_target in zip(opt_types, opt_targets) + ] + + return optimizer_params + + +def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]: + """Get non default values passed as parameters for the target profile.""" + if not device_args: + return [] + + dummy_parser = argparse.ArgumentParser() + add_target_options(dummy_parser) + args = dummy_parser.parse_args([]) + + params_name = { + action.dest: param_name + for param_name, action in dummy_parser._option_string_actions.items() # pylint: disable=protected-access + } + + non_default = [ + arg_name + for arg_name, arg_value in device_args.items() + if arg_name in args and vars(args)[arg_name] != arg_value + ] + + def construct_param(name: str, value: Any) -> List[str]: + """Construct parameter.""" + if isinstance(value, list): + return [str(item) for v in value for item in [name, v]] + + return [name, str(value)] + + return [ + item + for name in non_default + for item in construct_param(params_name[name], device_args[name]) + ] diff --git a/src/mlia/core/__init__.py b/src/mlia/core/__init__.py new file mode 100644 index 0000000..49b1830 --- /dev/null +++ b/src/mlia/core/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Core module. + +Core module contains the main components that are used in the workflow of +ML Inference Advisor: + - data collectors + - data analyzers + - advice producers + - event publishers + - event handlers + +The workflow of ML Inference Advisor consists of 3 stages: + - data collection + - data analysis + - advice generation + +Data is being passed from one stage to another via workflow executor. +Results (collected data, analyzed data, advice, etc) are being published via +publish/subscribe mechanishm. +""" diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py new file mode 100644 index 0000000..bda995c --- /dev/null +++ b/src/mlia/core/_typing.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for custom type hints.""" +from pathlib import Path +from typing import Literal +from typing import TextIO +from typing import Union + + +FileLike = TextIO +PathOrFileLike = Union[str, Path, FileLike] +OutputFormat = Literal["plain_text", "csv", "json"] diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py new file mode 100644 index 0000000..76cc1f2 --- /dev/null +++ b/src/mlia/core/advice_generation.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for advice generation.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from functools import wraps +from typing import Any +from typing import Callable +from typing import List +from typing import Union + +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.events import SystemEvent +from mlia.core.mixins import ContextMixin + + +@dataclass +class Advice: + """Base class for the advice.""" + + messages: List[str] + + +@dataclass +class AdviceEvent(SystemEvent): + """Advice event. + + This event is published for every produced advice. + + :param advice: Advice instance + """ + + advice: Advice + + +class AdviceProducer(ABC): + """Base class for the advice producer. + + Producer has two methods for advice generation: + - produce_advice - used to generate advice based on provided + data (analyzed data item from analyze stage) + - get_advice - used for getting generated advice + + Advice producers that have predefined advice could skip + implementation of produce_advice method. + """ + + @abstractmethod + def produce_advice(self, data_item: DataItem) -> None: + """Process data item and produce advice. + + :param data_item: piece of data that could be used + for advice generation + """ + + @abstractmethod + def get_advice(self) -> Union[Advice, List[Advice]]: + """Get produced advice.""" + + +class ContextAwareAdviceProducer(AdviceProducer, ContextMixin): + """Context aware advice producer. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ + + +class FactBasedAdviceProducer(ContextAwareAdviceProducer): + """Advice producer based on provided facts. + + This is an utility class that maintain list of generated Advice instances. + """ + + def __init__(self) -> None: + """Init advice producer.""" + self.advice: List[Advice] = [] + + def get_advice(self) -> Union[Advice, List[Advice]]: + """Get produced advice.""" + return self.advice + + def add_advice(self, messages: List[str]) -> None: + """Add advice.""" + self.advice.append(Advice(messages)) + + +def advice_category(*categories: AdviceCategory) -> Callable: + """Filter advice generation handler by advice category.""" + + def wrapper(handler: Callable) -> Callable: + """Wrap data handler.""" + + @wraps(handler) + def check_category(self: Any, *args: Any, **kwargs: Any) -> Any: + """Check if handler can produce advice for the requested category.""" + if not self.context.any_category_enabled(*categories): + return + + handler(self, *args, **kwargs) + + return check_category + + return wrapper diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py new file mode 100644 index 0000000..868d0c7 --- /dev/null +++ b/src/mlia/core/advisor.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Inference advisor module.""" +from abc import abstractmethod + +from mlia.core.common import NamedEntity +from mlia.core.context import Context +from mlia.core.workflow import WorkflowExecutor + + +class InferenceAdvisor(NamedEntity): + """Base class for inference advisors.""" + + @abstractmethod + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor execution.""" + + def run(self, context: Context) -> None: + """Run inference advisor.""" + executor = self.configure(context) + executor.run() diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py new file mode 100644 index 0000000..5fbad42 --- /dev/null +++ b/src/mlia/core/common.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common module. + +This module contains common interfaces/classess shared across +core module. +""" +from abc import ABC +from abc import abstractmethod +from enum import Enum +from typing import Any + +# This type is used as type alias for the items which are being passed around +# in advisor workflow. There are no restrictions on the type of the +# object. This alias used only to emphasize the nature of the input/output +# arguments. +DataItem = Any + + +class AdviceCategory(Enum): + """Advice category. + + Enumeration of advice categories supported by ML Inference Advisor. + """ + + OPERATORS = 1 + PERFORMANCE = 2 + OPTIMIZATION = 3 + ALL = 4 + + @classmethod + def from_string(cls, value: str) -> "AdviceCategory": + """Resolve enum value from string value.""" + category_names = [item.name for item in AdviceCategory] + if not value or value.upper() not in category_names: + raise Exception(f"Invalid advice category {value}") + + return AdviceCategory[value.upper()] + + +class NamedEntity(ABC): + """Entity with a name and description.""" + + @classmethod + @abstractmethod + def name(cls) -> str: + """Return name of the entity.""" diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py new file mode 100644 index 0000000..8b3dd2c --- /dev/null +++ b/src/mlia/core/context.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Context module. + +This module contains functionality related to the Context. +Context is an object that describes advisor working environment +and requested behavior (advice categories, input configuration +parameters). +""" +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Any +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + +from mlia.core.common import AdviceCategory +from mlia.core.events import DefaultEventPublisher +from mlia.core.events import EventHandler +from mlia.core.events import EventPublisher +from mlia.core.helpers import ActionResolver +from mlia.core.helpers import APIActionResolver + +logger = logging.getLogger(__name__) + + +class Context(ABC): + """Abstract class for the execution context.""" + + @abstractmethod + def get_model_path(self, model_filename: str) -> Path: + """Return path for the intermediate/optimized models. + + During workflow execution different parts of the advisor + require creating intermediate files for models. + + This method allows to provide paths where those models + could be saved. + + :param model_filename: filename of the model + """ + + @property + @abstractmethod + def event_publisher(self) -> EventPublisher: + """Return event publisher.""" + + @property + @abstractmethod + def event_handlers(self) -> Optional[List[EventHandler]]: + """Return list of the event_handlers.""" + + @property + @abstractmethod + def advice_category(self) -> Optional[AdviceCategory]: + """Return advice category.""" + + @property + @abstractmethod + def config_parameters(self) -> Optional[Mapping[str, Any]]: + """Return configuration parameters.""" + + @property + @abstractmethod + def action_resolver(self) -> ActionResolver: + """Return action resolver.""" + + @abstractmethod + def update( + self, + *, + advice_category: AdviceCategory, + event_handlers: List[EventHandler], + config_parameters: Mapping[str, Any], + ) -> None: + """Update context parameters.""" + + def category_enabled(self, category: AdviceCategory) -> bool: + """Check if category enabled.""" + return category == self.advice_category + + def any_category_enabled(self, *categories: AdviceCategory) -> bool: + """Return true if any category is enabled.""" + return self.advice_category in categories + + def register_event_handlers(self) -> None: + """Register event handlers.""" + self.event_publisher.register_event_handlers(self.event_handlers) + + +class ExecutionContext(Context): + """Execution context.""" + + def __init__( + self, + *, + advice_category: Optional[AdviceCategory] = None, + config_parameters: Optional[Mapping[str, Any]] = None, + working_dir: Optional[Union[str, Path]] = None, + event_handlers: Optional[List[EventHandler]] = None, + event_publisher: Optional[EventPublisher] = None, + verbose: bool = False, + logs_dir: str = "logs", + models_dir: str = "models", + action_resolver: Optional[ActionResolver] = None, + ) -> None: + """Init execution context. + + :param advice_category: requested advice category + :param config_parameters: dictionary like object with input parameters + :param working_dir: path to the directory that will be used as a place + to store temporary files, logs, models. If not provided then + current working directory will be used instead + :param event_handlers: optional list of event handlers + :param event_publisher: optional event publisher instance. If not provided + then default implementation of event publisher will be used + :param verbose: enable verbose output + :param logs_dir: name of the directory inside working directory where + log files will be stored + :param models_dir: name of the directory inside working directory where + temporary models will be stored + :param action_resolver: instance of the action resolver that could make + advice actionable + """ + self._advice_category = advice_category + self._config_parameters = config_parameters + + self._working_dir_path = Path.cwd() + if working_dir: + self._working_dir_path = Path(working_dir) + self._working_dir_path.mkdir(exist_ok=True) + + self._event_handlers = event_handlers + self._event_publisher = event_publisher or DefaultEventPublisher() + self.verbose = verbose + self.logs_dir = logs_dir + self.models_dir = models_dir + self._action_resolver = action_resolver or APIActionResolver() + + @property + def advice_category(self) -> Optional[AdviceCategory]: + """Return advice category.""" + return self._advice_category + + @advice_category.setter + def advice_category(self, advice_category: AdviceCategory) -> None: + """Setter for the advice category.""" + self._advice_category = advice_category + + @property + def config_parameters(self) -> Optional[Mapping[str, Any]]: + """Return configuration parameters.""" + return self._config_parameters + + @config_parameters.setter + def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None: + """Setter for the configuration parameters.""" + self._config_parameters = config_parameters + + @property + def event_handlers(self) -> Optional[List[EventHandler]]: + """Return list of the event handlers.""" + return self._event_handlers + + @event_handlers.setter + def event_handlers(self, event_handlers: List[EventHandler]) -> None: + """Setter for the event handlers.""" + self._event_handlers = event_handlers + + @property + def event_publisher(self) -> EventPublisher: + """Return event publisher.""" + return self._event_publisher + + @property + def action_resolver(self) -> ActionResolver: + """Return action resolver.""" + return self._action_resolver + + def get_model_path(self, model_filename: str) -> Path: + """Return path for the model.""" + models_dir_path = self._working_dir_path / self.models_dir + models_dir_path.mkdir(exist_ok=True) + + return models_dir_path / model_filename + + @property + def logs_path(self) -> Path: + """Return path to the logs directory.""" + return self._working_dir_path / self.logs_dir + + def update( + self, + *, + advice_category: AdviceCategory, + event_handlers: List[EventHandler], + config_parameters: Mapping[str, Any], + ) -> None: + """Update context parameters.""" + self._advice_category = advice_category + self._event_handlers = event_handlers + self._config_parameters = config_parameters + + def __str__(self) -> str: + """Return string representation.""" + category = ( + "" if self.advice_category is None else self.advice_category.name + ) + + return ( + f"ExecutionContext: working_dir={self._working_dir_path}, " + f"advice_category={category}, " + f"config_parameters={self.config_parameters}, " + f"verbose={self.verbose}" + ) diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py new file mode 100644 index 0000000..6adb41e --- /dev/null +++ b/src/mlia/core/data_analysis.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for data analysis.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from typing import List + +from mlia.core.common import DataItem +from mlia.core.mixins import ContextMixin + + +class DataAnalyzer(ABC): + """Base class for the data analysis. + + Purpose of this class is to extract valuable data out of + collected data which could be used for advice generation. + + This process consists of two steps: + - analyze every item of the collected data + - get analyzed data + """ + + @abstractmethod + def analyze_data(self, data_item: DataItem) -> None: + """Analyze data. + + :param data_item: item of the collected data + """ + + @abstractmethod + def get_analyzed_data(self) -> List[DataItem]: + """Get analyzed data.""" + + +class ContextAwareDataAnalyzer(DataAnalyzer, ContextMixin): + """Context aware data analyzer. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ + + +@dataclass +class Fact: + """Base class for the facts. + + Fact represents some piece of knowledge about collected + data. + """ + + +class FactExtractor(ContextAwareDataAnalyzer): + """Data analyzer based on extracting facts. + + Utility class that makes fact extraction easier. + Class maintains list of the extracted facts. + """ + + def __init__(self) -> None: + """Init fact extractor.""" + self.facts: List[Fact] = [] + + def get_analyzed_data(self) -> List[DataItem]: + """Return list of the collected facts.""" + return self.facts + + def add_fact(self, fact: Fact) -> None: + """Add fact.""" + self.facts.append(fact) diff --git a/src/mlia/core/data_collection.py b/src/mlia/core/data_collection.py new file mode 100644 index 0000000..43b6d1c --- /dev/null +++ b/src/mlia/core/data_collection.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for data collection. + +This module contains base classes for the first stage +of the ML Inference Advisor workflow - data collection. +""" +from abc import abstractmethod + +from mlia.core.common import DataItem +from mlia.core.common import NamedEntity +from mlia.core.mixins import ContextMixin +from mlia.core.mixins import ParameterResolverMixin + + +class DataCollector(NamedEntity): + """Base class for the data collection. + + Data collection is the first step in the process of the advice + generation. + + Different implementations of this class can provide various + information about model or device. This information is being used + at later stages. + """ + + @abstractmethod + def collect_data(self) -> DataItem: + """Collect data.""" + + +class ContextAwareDataCollector(DataCollector, ContextMixin, ParameterResolverMixin): + """Context aware data collector. + + This class makes easier access to the Context object. Context object could + be automatically injected during workflow configuration. + """ diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py new file mode 100644 index 0000000..7d6beb1 --- /dev/null +++ b/src/mlia/core/errors.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""MLIA exceptions module.""" + + +class ConfigurationError(Exception): + """Configuration error.""" + + +class FunctionalityNotSupportedError(Exception): + """Functionality is not supported error.""" + + def __init__(self, reason: str, description: str) -> None: + """Init exception.""" + super().__init__(f"{reason}: {description}") + + self.reason = reason + self.description = description diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py new file mode 100644 index 0000000..10aec86 --- /dev/null +++ b/src/mlia/core/events.py @@ -0,0 +1,455 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for the events and related functionality. + +This module represents one of the main component of the workflow - +events publishing and provides a way for delivering results to the +calling application. + +Each component of the workflow can generate events of specific type. +Application can subscribe and react to those events. +""" +import traceback +import uuid +from abc import ABC +from abc import abstractmethod +from contextlib import contextmanager +from dataclasses import asdict +from dataclasses import dataclass +from dataclasses import field +from functools import singledispatchmethod +from typing import Any +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Tuple + +from mlia.core.common import DataItem + + +@dataclass +class Event: + """Base class for the events. + + This class is used as a root node of the events class hierarchy. + """ + + event_id: str = field(init=False) + + def __post_init__(self) -> None: + """Generate unique ID for the event.""" + self.event_id = str(uuid.uuid4()) + + def compare_without_id(self, other: "Event") -> bool: + """Compare two events without event_id field.""" + if not isinstance(other, Event) or self.__class__ != other.__class__: + return False + + self_as_dict = asdict(self) + self_as_dict.pop("event_id") + + other_as_dict = asdict(other) + other_as_dict.pop("event_id") + + return self_as_dict == other_as_dict + + +@dataclass +class ChildEvent(Event): + """Child event. + + This class could be used to link event with the parent event. + """ + + parent_event_id: str + + +@dataclass +class ActionStartedEvent(Event): + """Action started event. + + This event is published when some action has been started. + """ + + action_type: str + params: Optional[Dict] = None + + +@dataclass +class SubActionEvent(ChildEvent): + """SubAction event. + + This event could be used to represent some action during parent action. + """ + + action_type: str + params: Optional[Dict] = None + + +@dataclass +class ActionFinishedEvent(ChildEvent): + """Action finished event. + + This event is published when some action has been finished. + """ + + +@dataclass +class SystemEvent(Event): + """System event. + + System event class represents events that published by components + of the core module. Most common example is an workflow executor + that publishes number of system events for starting/completion + of different stages/workflow. + + Events that published by components outside of core module should not + use this class as base class. + """ + + +@dataclass +class ExecutionStartedEvent(SystemEvent): + """Execution started event. + + This event is published when workflow execution started. + """ + + +@dataclass +class ExecutionFinishedEvent(SystemEvent): + """Execution finished event. + + This event is published when workflow execution finished. + """ + + +@dataclass +class ExecutionFailedEvent(SystemEvent): + """Execution failed event.""" + + err: Exception + + +@dataclass +class DataCollectionStageStartedEvent(SystemEvent): + """Data collection stage started. + + This event is published when data collection stage started. + """ + + +@dataclass +class DataCollectorSkippedEvent(SystemEvent): + """Data collector skipped event. + + This event is published when particular data collector can + not provide data for the provided parameters. + """ + + data_collector: str + reason: str + + +@dataclass +class DataCollectionStageFinishedEvent(SystemEvent): + """Data collection stage finished. + + This event is published when data collection stage finished. + """ + + +@dataclass +class DataAnalysisStageStartedEvent(SystemEvent): + """Data analysis stage started. + + This event is published when data analysis stage started. + """ + + +@dataclass +class DataAnalysisStageFinishedEvent(SystemEvent): + """Data analysis stage finished. + + This event is published when data analysis stage finished. + """ + + +@dataclass +class AdviceStageStartedEvent(SystemEvent): + """Advace producing stage started. + + This event is published when advice generation stage started. + """ + + +@dataclass +class AdviceStageFinishedEvent(SystemEvent): + """Advace producing stage finished. + + This event is published when advice generation stage finished. + """ + + +@dataclass +class CollectedDataEvent(SystemEvent): + """Collected data event. + + This event is published for every collected data item. + + :param data_item: collected data item + """ + + data_item: DataItem + + +@dataclass +class AnalyzedDataEvent(SystemEvent): + """Analyzed data event. + + This event is published for every analyzed data item. + + :param data_item: analyzed data item + """ + + data_item: DataItem + + +class EventHandler: + """Base class for the event handlers. + + Each event handler should derive from this base class. + """ + + def handle_event(self, event: Event) -> None: + """Handle the event. + + By default all published events are being passed to each + registered event handler. It is handler's responsibility + to filter events that it interested in. + """ + + +class DebugEventHandler(EventHandler): + """Event handler for debugging purposes. + + This handler could print every published event to the + standard output. + """ + + def __init__(self, with_stacktrace: bool = False) -> None: + """Init event handler. + + :param with_stacktrace: enable printing stacktrace of the + place where event publishing occurred. + """ + self.with_stacktrace = with_stacktrace + + def handle_event(self, event: Event) -> None: + """Handle event.""" + print(f"Got event {event}") + + if self.with_stacktrace: + traceback.print_stack() + + +class EventDispatcherMetaclass(type): + """Metaclass for event dispatching. + + It could be tedious to check type of the published event + inside event handler. Instead the following convention could be + established: if method name of the class starts with some + prefix then it is considered to be event handler of particular + type. + + This metaclass goes through the list of class methods and + links all methods with the prefix "on_" to the common dispatcher + method. + """ + + def __new__( + cls, + clsname: str, + bases: Tuple, + namespace: Dict[str, Any], + event_handler_method_prefix: str = "on_", + ) -> Any: + """Create event dispatcher and link event handlers.""" + new_class = super().__new__(cls, clsname, bases, namespace) + + @singledispatchmethod + def dispatcher(_self: Any, _event: Event) -> Any: + """Event dispatcher.""" + + # get all class methods which starts with particular prefix + event_handler_methods = ( + (item_name, item) + for item_name in dir(new_class) + if callable((item := getattr(new_class, item_name))) + and item_name.startswith(event_handler_method_prefix) + ) + + # link all collected event handlers to one dispatcher method + for method_name, method in event_handler_methods: + event_handler = dispatcher.register(method) + setattr(new_class, method_name, event_handler) + + # override default handle_event method, replace it with the + # dispatcher + setattr(new_class, "handle_event", dispatcher) + + return new_class + + +class EventDispatcher(EventHandler, metaclass=EventDispatcherMetaclass): + """Event dispatcher.""" + + +class EventPublisher(ABC): + """Base class for the event publisher. + + Event publisher is a intermidiate component between event emitter + and event consumer. + """ + + @abstractmethod + def register_event_handler(self, event_handler: EventHandler) -> None: + """Register event handler. + + :param event_handler: instance of the event handler + """ + + def register_event_handlers( + self, event_handlers: Optional[List[EventHandler]] + ) -> None: + """Register event handlers. + + Can be used for batch registration of the event handlers: + + :param event_handlers: list of the event handler instances + """ + if not event_handlers: + return + + for handler in event_handlers: + self.register_event_handler(handler) + + @abstractmethod + def publish_event(self, event: Event) -> None: + """Publish the event. + + Deliver the event to the all registered event handlers. + + :param event: event instance + """ + + +class DefaultEventPublisher(EventPublisher): + """Default event publishing implementation. + + Simple implementation that maintains list of the registered event + handlers. + """ + + def __init__(self) -> None: + """Init the event publisher.""" + self.handlers: List[EventHandler] = [] + + def register_event_handler(self, event_handler: EventHandler) -> None: + """Register the event handler. + + :param event_handler: instance of the event handler + """ + self.handlers.append(event_handler) + + def publish_event(self, event: Event) -> None: + """Publish the event. + + Publisher does not catch exceptions that could be raised by event handlers. + """ + for handler in self.handlers: + handler.handle_event(event) + + +@contextmanager +def stage( + publisher: EventPublisher, events: Tuple[Event, Event] +) -> Generator[None, None, None]: + """Generate events before and after stage. + + This context manager could be used to mark start/finish + execution of a particular logical part of the workflow. + """ + started, finished = events + + publisher.publish_event(started) + yield + publisher.publish_event(finished) + + +@contextmanager +def action( + publisher: EventPublisher, action_type: str, params: Optional[Dict] = None +) -> Generator[None, None, None]: + """Generate events before and after action.""" + action_started = ActionStartedEvent(action_type, params) + action_finished = ActionFinishedEvent(action_started.event_id) + + publisher.publish_event(action_started) + yield + publisher.publish_event(action_finished) + + +class SystemEventsHandler(EventDispatcher): + """System events handler.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + + def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: + """Handle ExecutionFinished event.""" + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + + def on_data_collection_stage_finished( + self, event: DataCollectionStageFinishedEvent + ) -> None: + """Handle DataCollectionStageFinished event.""" + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + + def on_data_analysis_stage_started( + self, event: DataAnalysisStageStartedEvent + ) -> None: + """Handle DataAnalysisStageStartedEvent event.""" + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinishedEvent event.""" + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinished event.""" + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedData event.""" + + def on_analyzed_data(self, event: AnalyzedDataEvent) -> None: + """Handle AnalyzedData event.""" + + def on_action_started(self, event: ActionStartedEvent) -> None: + """Handle ActionStarted event.""" + + def on_action_finished(self, event: ActionFinishedEvent) -> None: + """Handle ActionFinished event.""" diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py new file mode 100644 index 0000000..d10ea5d --- /dev/null +++ b/src/mlia/core/helpers.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for various helper classes.""" +# pylint: disable=no-self-use, unused-argument +from typing import Any +from typing import List + + +class ActionResolver: + """Helper class for generating actions (e.g. commands with parameters).""" + + def apply_optimizations(self, **kwargs: Any) -> List[str]: + """Return action details for applying optimizations.""" + return [] + + def supported_operators_info(self) -> List[str]: + """Return action details for generating supported ops report.""" + return [] + + def check_performance(self) -> List[str]: + """Return action details for checking performance.""" + return [] + + def check_operator_compatibility(self) -> List[str]: + """Return action details for checking op compatibility.""" + return [] + + def operator_compatibility_details(self) -> List[str]: + """Return action details for getting more information about op compatibility.""" + return [] + + def optimization_details(self) -> List[str]: + """Return action detail for getting information about optimizations.""" + return [] + + +class APIActionResolver(ActionResolver): + """Helper class for the actions performed through API.""" diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py new file mode 100644 index 0000000..ee03100 --- /dev/null +++ b/src/mlia/core/mixins.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Mixins module.""" +from typing import Any +from typing import Optional + +from mlia.core.context import Context + + +class ContextMixin: + """Mixin for injecting context object.""" + + context: Context + + def set_context(self, context: Context) -> None: + """Context setter.""" + self.context = context + + +class ParameterResolverMixin: + """Mixin for parameter resolving.""" + + context: Context + + def get_parameter( + self, + section: str, + name: str, + expected: bool = True, + expected_type: Optional[type] = None, + context: Optional[Context] = None, + ) -> Any: + """Get parameter value.""" + ctx = context or self.context + + if ctx.config_parameters is None: + raise Exception("Configuration parameters are not set") + + section_params = ctx.config_parameters.get(section) + if section_params is None or not isinstance(section_params, dict): + raise Exception( + f"Parameter section {section} has wrong format, " + "expected to be a dictionary" + ) + + value = section_params.get(name) + + if not value and expected: + raise Exception(f"Parameter {name} is not set") + + if value and expected_type is not None and not isinstance(value, expected_type): + raise Exception(f"Parameter {name} expected to have type {expected_type}") + + return value diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py new file mode 100644 index 0000000..5433d5c --- /dev/null +++ b/src/mlia/core/performance.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for performance estimation.""" +from abc import abstractmethod +from typing import Callable +from typing import Generic +from typing import List +from typing import TypeVar + + +ModelType = TypeVar("ModelType") # pylint: disable=invalid-name +PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name + + +class PerformanceEstimator(Generic[ModelType, PerfMetricsType]): + """Base class for the performance estimation.""" + + @abstractmethod + def estimate(self, model: ModelType) -> PerfMetricsType: + """Estimate performance.""" + + +def estimate_performance( + original_model: ModelType, + estimator: PerformanceEstimator[ModelType, PerfMetricsType], + model_transformations: List[Callable[[ModelType], ModelType]], +) -> List[PerfMetricsType]: + """Estimate performance impact. + + This function estimates performance impact on model performance after + applying provided transformations/optimizations. + + :param original_model: object that represents a model, could be + instance of the model or path to the model. This depends on + provided performance estimator. + :param estimator: performance estimator + :param model_transformations: list of the callables each of those + returns object that represents optimized model + """ + original_metrics = estimator.estimate(original_model) + + optimized_metrics = [ + estimator.estimate(transform(original_model)) + for transform in model_transformations + ] + + return [original_metrics, *optimized_metrics] diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py new file mode 100644 index 0000000..1b75bb4 --- /dev/null +++ b/src/mlia/core/reporting.py @@ -0,0 +1,762 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Reporting module.""" +import csv +import json +import logging +from abc import ABC +from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager +from contextlib import ExitStack +from dataclasses import dataclass +from functools import partial +from io import TextIOWrapper +from pathlib import Path +from textwrap import fill +from textwrap import indent +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +from mlia.core._typing import FileLike +from mlia.core._typing import OutputFormat +from mlia.core._typing import PathOrFileLike +from mlia.utils.console import apply_style +from mlia.utils.console import produce_table +from mlia.utils.logging import LoggerWriter +from mlia.utils.types import is_list_of + +logger = logging.getLogger(__name__) + + +class Report(ABC): + """Abstract class for the report.""" + + @abstractmethod + def to_json(self, **kwargs: Any) -> Any: + """Convert to json serializible format.""" + + @abstractmethod + def to_csv(self, **kwargs: Any) -> List[Any]: + """Convert to csv serializible format.""" + + @abstractmethod + def to_plain_text(self, **kwargs: Any) -> str: + """Convert to human readable format.""" + + +class ReportItem: + """Item of the report.""" + + def __init__( + self, + name: str, + alias: Optional[str] = None, + value: Optional[Union[str, int, "Cell"]] = None, + nested_items: Optional[List["ReportItem"]] = None, + ) -> None: + """Init the report item.""" + self.name = name + self.alias = alias + self.value = value + self.nested_items = nested_items or [] + + @property + def compound(self) -> bool: + """Return true if item has nested items.""" + return self.nested_items is not None and len(self.nested_items) > 0 + + @property + def raw_value(self) -> Any: + """Get actual item value.""" + val = self.value + if isinstance(val, Cell): + return val.value + + return val + + +@dataclass +class Format: + """Column or cell format. + + Format could be applied either to a column or an individual cell. + + :param wrap_width: width of the wrapped text value + :param str_fmt: string format to be applied to the value + :param style: text style + """ + + wrap_width: Optional[int] = None + str_fmt: Optional[Union[str, Callable[[Any], str]]] = None + style: Optional[str] = None + + +@dataclass +class Cell: + """Cell definition. + + This a wrapper class for a particular value in the table. Could be used + for applying specific format to this value. + """ + + value: Any + fmt: Optional[Format] = None + + def _apply_style(self, value: str) -> str: + """Apply style to the value.""" + if self.fmt and self.fmt.style: + value = apply_style(value, self.fmt.style) + + return value + + def _get_value(self) -> str: + """Return cell value.""" + if self.fmt: + if isinstance(self.fmt.str_fmt, str): + return "{:{fmt}}".format(self.value, fmt=self.fmt.str_fmt) + + if callable(self.fmt.str_fmt): + return self.fmt.str_fmt(self.value) + + return str(self.value) + + def __str__(self) -> str: + """Return string representation.""" + val = self._get_value() + return self._apply_style(val) + + def to_csv(self) -> Any: + """Cell definition for csv.""" + return self.value + + def to_json(self) -> Any: + """Cell definition for json.""" + return self.value + + +class CountAwareCell(Cell): + """Count aware cell.""" + + def __init__( + self, + value: Optional[Union[int, float]], + singular: str, + plural: str, + format_string: str = ",d", + ): + """Init cell instance.""" + self.unit = singular if value == 1 else plural + + def format_value(val: Optional[Union[int, float]]) -> str: + """Provide string representation for the value.""" + if val is None: + return "" + + if val == 1: + return f"1 {singular}" + + return f"{val:{format_string}} {plural}" + + super().__init__(value, Format(str_fmt=format_value)) + + def to_csv(self) -> Any: + """Cell definition for csv.""" + return {"value": self.value, "unit": self.unit} + + def to_json(self) -> Any: + """Cell definition for json.""" + return {"value": self.value, "unit": self.unit} + + +class BytesCell(CountAwareCell): + """Cell that represents memory size.""" + + def __init__(self, value: Optional[int]) -> None: + """Init cell instance.""" + super().__init__(value, "byte", "bytes") + + +class CyclesCell(CountAwareCell): + """Cell that represents cycles.""" + + def __init__(self, value: Optional[Union[int, float]]) -> None: + """Init cell instance.""" + super().__init__(value, "cycle", "cycles", ",.0f") + + +class ClockCell(CountAwareCell): + """Cell that represents clock value.""" + + def __init__(self, value: Optional[Union[int, float]]) -> None: + """Init cell instance.""" + super().__init__(value, "Hz", "Hz", ",.0f") + + +class Column: + """Column definition.""" + + def __init__( + self, + header: str, + alias: Optional[str] = None, + fmt: Optional[Format] = None, + only_for: Optional[List[str]] = None, + ) -> None: + """Init column definition. + + :param header: column's header + :param alias: columns's alias, could be used as column's name + :param fmt: format that will be applied for all column's values + :param only_for: list of the formats where this column should be + represented. May be used to differentiate data representation in + different formats + """ + self.header = header + self.alias = alias + self.fmt = fmt + self.only_for = only_for + + def supports_format(self, fmt: str) -> bool: + """Return true if column should be shown.""" + return not self.only_for or fmt in self.only_for + + +class NestedReport(Report): + """Report with nested items.""" + + def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None: + """Init nested report.""" + self.name = name + self.alias = alias + self.items = items + + def to_csv(self, **kwargs: Any) -> List[Any]: + """Convert to csv serializible format.""" + result = {} + + def collect_item_values( + item: ReportItem, + _parent: Optional[ReportItem], + _prev: Optional[ReportItem], + _level: int, + ) -> None: + """Collect item values into a dictionary..""" + if item.value is None: + return + + if not isinstance(item.value, Cell): + result[item.alias] = item.raw_value + return + + csv_value = item.value.to_csv() + if isinstance(csv_value, dict): + csv_value = { + f"{item.alias}_{key}": value for key, value in csv_value.items() + } + else: + csv_value = {item.alias: csv_value} + + result.update(csv_value) + + self._traverse(self.items, collect_item_values) + + # make list out of the result dictionary + # first element - keys of the dictionary as headers + # second element - list of the dictionary values + return list(zip(*result.items())) + + def to_json(self, **kwargs: Any) -> Any: + """Convert to json serializible format.""" + per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict) + result = per_parent[None] + + def collect_as_dicts( + item: ReportItem, + parent: Optional[ReportItem], + _prev: Optional[ReportItem], + _level: int, + ) -> None: + """Collect item values as nested dictionaries.""" + parent_dict = per_parent[parent] + + if item.compound: + item_dict = per_parent[item] + parent_dict[item.alias] = item_dict + else: + out_dis = ( + item.value.to_json() + if isinstance(item.value, Cell) + else item.raw_value + ) + parent_dict[item.alias] = out_dis + + self._traverse(self.items, collect_as_dicts) + + return {self.alias: result} + + def to_plain_text(self, **kwargs: Any) -> str: + """Convert to human readable format.""" + header = f"{self.name}:\n" + processed_items = [] + + def convert_to_text( + item: ReportItem, + _parent: Optional[ReportItem], + prev: Optional[ReportItem], + level: int, + ) -> None: + """Convert item to text representation.""" + if level >= 1 and prev is not None and (item.compound or prev.compound): + processed_items.append("") + + val = self._item_value(item, level) + processed_items.append(val) + + self._traverse(self.items, convert_to_text) + body = "\n".join(processed_items) + + return header + body + + @staticmethod + def _item_value( + item: ReportItem, level: int, tab_size: int = 2, column_width: int = 35 + ) -> str: + """Get report item value.""" + shift = " " * tab_size * level + if item.value is None: + return f"{shift}{item.name}:" + + col1 = f"{shift}{item.name}".ljust(column_width) + col2 = f"{item.value}".rjust(column_width) + + return col1 + col2 + + def _traverse( + self, + items: List[ReportItem], + visit_item: Callable[ + [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None + ], + level: int = 1, + parent: Optional[ReportItem] = None, + ) -> None: + """Traverse through items.""" + prev = None + for item in items: + visit_item(item, parent, prev, level) + + self._traverse(item.nested_items, visit_item, level + 1, item) + prev = item + + +class Table(Report): + """Table definition. + + This class could be used for representing tabular data. + """ + + def __init__( + self, + columns: List[Column], + rows: Collection, + name: str, + alias: Optional[str] = None, + notes: Optional[str] = None, + ) -> None: + """Init table definition. + + :param columns: list of the table's columns + :param rows: list of the table's rows + :param name: name of the table + :param alias: alias for the table + """ + self.columns = columns + self.rows = rows + self.name = name + self.alias = alias + self.notes = notes + + def to_json(self, **kwargs: Any) -> Iterable: + """Convert table to dict object.""" + + def item_to_json(item: Any) -> Any: + value = item + if isinstance(item, Cell): + value = item.value + + if isinstance(value, Table): + return value.to_json() + + return value + + json_data = [ + { + col.alias or col.header: item_to_json(item) + for (item, col) in zip(row, self.columns) + if col.supports_format("json") + } + for row in self.rows + ] + + if not self.alias: + return json_data + + return {self.alias: json_data} + + def to_plain_text(self, **kwargs: Any) -> str: + """Produce report in human readable format.""" + nested = kwargs.get("nested", False) + show_headers = kwargs.get("show_headers", True) + show_title = kwargs.get("show_title", True) + table_style = kwargs.get("table_style", "default") + space = kwargs.get("space", False) + + headers = ( + [] if (nested or not show_headers) else [c.header for c in self.columns] + ) + + def item_to_plain_text(item: Any, col: Column) -> str: + """Convert item to text.""" + if isinstance(item, Table): + return item.to_plain_text(nested=True, **kwargs) + + if is_list_of(item, str): + as_text = "\n".join(item) + else: + as_text = str(item) + + if col.fmt and col.fmt.wrap_width: + as_text = fill(as_text, col.fmt.wrap_width) + + return as_text + + title = "" + if show_title and not nested: + title = f"{self.name}:\n" + + if space in (True, "top"): + title = "\n" + title + + footer = "" + if space in (True, "bottom"): + footer = "\n" + if self.notes: + footer = "\n" + self.notes + + formatted_rows = ( + ( + item_to_plain_text(item, col) + for item, col in zip(row, self.columns) + if col.supports_format("plain_text") + ) + for row in self.rows + ) + + if space == "between": + formatted_table = "\n\n".join( + produce_table([row], table_style=table_style) for row in formatted_rows + ) + else: + formatted_table = produce_table( + formatted_rows, + headers=headers, + table_style="nested" if nested else table_style, + ) + + return title + formatted_table + footer + + def to_csv(self, **kwargs: Any) -> List[Any]: + """Convert table to csv format.""" + headers = [[c.header for c in self.columns if c.supports_format("csv")]] + + def item_data(item: Any) -> Any: + if isinstance(item, Cell): + return item.value + + if isinstance(item, Table): + return ";".join( + str(item_data(cell)) for row in item.rows for cell in row + ) + + return item + + rows = [ + [ + item_data(item) + for (item, col) in zip(row, self.columns) + if col.supports_format("csv") + ] + for row in self.rows + ] + + return headers + rows + + +class SingleRow(Table): + """Table with a single row.""" + + def to_plain_text(self, **kwargs: Any) -> str: + """Produce report in human readable format.""" + if len(self.rows) != 1: + raise Exception("Table should have only one row") + + items = "\n".join( + column.header.ljust(35) + str(item).rjust(25) + for row in self.rows + for item, column in zip(row, self.columns) + if column.supports_format("plain_text") + ) + + return "\n".join([f"{self.name}:", indent(items, " ")]) + + +class CompoundReport(Report): + """Compound report. + + This class could be used for producing multiple reports at once. + """ + + def __init__(self, reports: List[Report]) -> None: + """Init compound report instance.""" + self.reports = reports + + def to_json(self, **kwargs: Any) -> Any: + """Convert to json serializible format. + + Method attempts to create compound dictionary based on provided + parts. + """ + result: Dict[str, Any] = {} + for item in self.reports: + result.update(item.to_json(**kwargs)) + + return result + + def to_csv(self, **kwargs: Any) -> List[Any]: + """Convert to csv serializible format. + + CSV format does support only one table. In order to be able to export + multiply tables they should be merged before that. This method tries to + do next: + + - if all tables have the same length then just concatenate them + - if one table has many rows and other just one (two with headers), then + for each row in table with many rows duplicate values from other tables + """ + csv_data = [item.to_csv() for item in self.reports] + lengths = [len(csv_item_data) for csv_item_data in csv_data] + + same_length = len(set(lengths)) == 1 + if same_length: + # all lists are of the same length, merge them into one + return [[cell for item in row for cell in item] for row in zip(*csv_data)] + + main_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) > 2] + one_main_obj = len(main_obj_indexes) == 1 + + reference_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) == 2] + other_only_ref_objs = len(reference_obj_indexes) == len(csv_data) - 1 + + if one_main_obj and other_only_ref_objs: + main_obj = csv_data[main_obj_indexes[0]] + return [ + item + + [ + ref_item + for ref_table_index in reference_obj_indexes + for ref_item in csv_data[ref_table_index][0 if i == 0 else 1] + ] + for i, item in enumerate(main_obj) + ] + + # write tables one after another if there is no other options + return [row for item in csv_data for row in item] + + def to_plain_text(self, **kwargs: Any) -> str: + """Convert to human readable format.""" + return "\n".join(item.to_plain_text(**kwargs) for item in self.reports) + + +class CompoundFormatter: + """Compound data formatter.""" + + def __init__(self, formatters: List[Callable]) -> None: + """Init compound formatter.""" + self.formatters = formatters + + def __call__(self, data: Any) -> Report: + """Produce report.""" + reports = [formatter(item) for item, formatter in zip(data, self.formatters)] + return CompoundReport(reports) + + +class CustomJSONEncoder(json.JSONEncoder): + """Custom JSON encoder.""" + + def default(self, o: Any) -> Any: + """Support numpy types.""" + if isinstance(o, np.integer): + return int(o) + + if isinstance(o, np.floating): + return float(o) + + return json.JSONEncoder.default(self, o) + + +def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: + """Produce report in json format.""" + json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder) + print(json_str, file=output) + + +def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: + """Produce report in text format.""" + print(report.to_plain_text(**kwargs), file=output) + + +def csv_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: + """Produce report in csv format.""" + csv_writer = csv.writer(output) + csv_writer.writerows(report.to_csv(**kwargs)) + + +def produce_report( + data: Any, + formatter: Callable[[Any], Report], + fmt: OutputFormat = "plain_text", + output: Optional[PathOrFileLike] = None, + **kwargs: Any, +) -> None: + """Produce report based on provided data.""" + # check if provided format value is supported + formats = {"json": json_reporter, "plain_text": text_reporter, "csv": csv_reporter} + if fmt not in formats: + raise Exception(f"Unknown format {fmt}") + + if output is None: + output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO)) + + with ExitStack() as exit_stack: + if isinstance(output, (str, Path)): + # open file and add it to the ExitStack context manager + # in that case it will be automatically closed + stream = exit_stack.enter_context(open(output, "w", encoding="utf-8")) + else: + stream = cast(TextIOWrapper, output) + + # convert data into serializable form + formatted_data = formatter(data) + # find handler for the format + format_handler = formats[fmt] + # produce report in requested format + format_handler(formatted_data, stream, **kwargs) + + +class Reporter: + """Reporter class.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + output_format: OutputFormat = "plain_text", + print_as_submitted: bool = True, + ) -> None: + """Init reporter instance.""" + self.formatter_resolver = formatter_resolver + self.output_format = output_format + self.print_as_submitted = print_as_submitted + + self.data: List[Tuple[Any, Callable[[Any], Report]]] = [] + self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = [] + + def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: + """Submit data for the report.""" + if self.print_as_submitted and not delay_print: + produce_report( + data_item, + self.formatter_resolver(data_item), + fmt="plain_text", + **kwargs, + ) + + formatter = _apply_format_parameters( + self.formatter_resolver(data_item), self.output_format, **kwargs + ) + self.data.append((data_item, formatter)) + + if delay_print: + self.delayed.append((data_item, formatter)) + + def print_delayed(self) -> None: + """Print delayed reports.""" + if not self.delayed: + return + + data, formatters = zip(*self.delayed) + produce_report( + data, + formatter=CompoundFormatter(formatters), + fmt="plain_text", + ) + self.delayed = [] + + def generate_report(self, output: Optional[PathOrFileLike]) -> None: + """Generate report.""" + already_printed = ( + self.print_as_submitted + and self.output_format == "plain_text" + and output is None + ) + if not self.data or already_printed: + return + + data, formatters = zip(*self.data) + produce_report( + data, + formatter=CompoundFormatter(formatters), + fmt=self.output_format, + output=output, + ) + + +@contextmanager +def get_reporter( + output_format: OutputFormat, + output: Optional[PathOrFileLike], + formatter_resolver: Callable[[Any], Callable[[Any], Report]], +) -> Generator[Reporter, None, None]: + """Get reporter and generate report.""" + reporter = Reporter(formatter_resolver, output_format) + + yield reporter + + reporter.generate_report(output) + + +def _apply_format_parameters( + formatter: Callable[[Any], Report], output_format: OutputFormat, **kwargs: Any +) -> Callable[[Any], Report]: + """Wrap report method.""" + + def wrapper(data: Any) -> Report: + report = formatter(data) + method_name = f"to_{output_format}" + method = getattr(report, method_name) + setattr(report, method_name, partial(method, **kwargs)) + + return report + + return wrapper diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py new file mode 100644 index 0000000..0245087 --- /dev/null +++ b/src/mlia/core/workflow.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for executors. + +This module contains implementation of the workflow +executors. +""" +import itertools +from abc import ABC +from abc import abstractmethod +from functools import wraps +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import AdviceEvent +from mlia.core.advice_generation import AdviceProducer +from mlia.core.common import DataItem +from mlia.core.context import Context +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.errors import FunctionalityNotSupportedError +from mlia.core.events import AdviceStageFinishedEvent +from mlia.core.events import AdviceStageStartedEvent +from mlia.core.events import AnalyzedDataEvent +from mlia.core.events import CollectedDataEvent +from mlia.core.events import DataAnalysisStageFinishedEvent +from mlia.core.events import DataAnalysisStageStartedEvent +from mlia.core.events import DataCollectionStageFinishedEvent +from mlia.core.events import DataCollectionStageStartedEvent +from mlia.core.events import DataCollectorSkippedEvent +from mlia.core.events import Event +from mlia.core.events import ExecutionFailedEvent +from mlia.core.events import ExecutionFinishedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.events import stage +from mlia.core.mixins import ContextMixin + + +class WorkflowExecutor(ABC): + """Base workflow executor.""" + + @abstractmethod + def run(self) -> None: + """Run the module.""" + + +STAGE_COLLECTION = ( + DataCollectionStageStartedEvent(), + DataCollectionStageFinishedEvent(), +) +STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEvent()) +STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent()) + + +def on_stage(stage_events: Tuple[Event, Event]) -> Callable: + """Mark start/finish of the stage with appropriate events.""" + + def wrapper(method: Callable) -> Callable: + """Wrap method.""" + + @wraps(method) + def publish_events(self: Any, *args: Any, **kwargs: Any) -> Any: + """Publish events before and after execution.""" + with stage(self.context.event_publisher, stage_events): + return method(self, *args, **kwargs) + + return publish_events + + return wrapper + + +class DefaultWorkflowExecutor(WorkflowExecutor): + """Default module executor. + + This is a default implementation of the workflow executor. + All components are launched sequentually in the same process. + """ + + def __init__( + self, + context: Context, + collectors: Sequence[DataCollector], + analyzers: Sequence[DataAnalyzer], + producers: Sequence[AdviceProducer], + before_start_events: Optional[Sequence[Event]] = None, + ): + """Init default workflow executor. + + :param context: Context instance + :param collectors: List of the data collectors + :param analyzers: List of the data analyzers + :param producers: List of the advice producers + :param before_start_events: Optional list of the custom events that + should be published before start of the worfkow execution. + """ + self.context = context + self.collectors = collectors + self.analyzers = analyzers + self.producers = producers + self.before_start_events = before_start_events + + def run(self) -> None: + """Run the workflow.""" + self.inject_context() + self.context.register_event_handlers() + + try: + self.publish(ExecutionStartedEvent()) + + self.before_start() + + collected_data = self.collect_data() + analyzed_data = self.analyze_data(collected_data) + + self.produce_advice(analyzed_data) + except Exception as err: # pylint: disable=broad-except + self.publish(ExecutionFailedEvent(err)) + else: + self.publish(ExecutionFinishedEvent()) + + def before_start(self) -> None: + """Run actions before start of the workflow execution.""" + events = self.before_start_events or [] + for event in events: + self.publish(event) + + @on_stage(STAGE_COLLECTION) + def collect_data(self) -> List[DataItem]: + """Collect data. + + Run each of data collector components and return list of + the collected data items. + """ + collected_data = [] + for collector in self.collectors: + try: + if (data_item := collector.collect_data()) is not None: + collected_data.append(data_item) + self.publish(CollectedDataEvent(data_item)) + except FunctionalityNotSupportedError as err: + self.publish(DataCollectorSkippedEvent(collector.name(), str(err))) + + return collected_data + + @on_stage(STAGE_ANALYSIS) + def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]: + """Analyze data. + + Pass each collected data item into each data analyzer and + return analyzed data. + + :param collected_data: list of collected data items + """ + analyzed_data = [] + for analyzer in self.analyzers: + for item in collected_data: + analyzer.analyze_data(item) + + for data_item in analyzer.get_analyzed_data(): + analyzed_data.append(data_item) + + self.publish(AnalyzedDataEvent(data_item)) + return analyzed_data + + @on_stage(STAGE_ADVICE) + def produce_advice(self, analyzed_data: List[DataItem]) -> None: + """Produce advice. + + Pass each analyzed data item into each advice producer and + publish generated advice. + + :param analyzed_data: list of analyzed data items + """ + for producer in self.producers: + for data_item in analyzed_data: + producer.produce_advice(data_item) + + advice = producer.get_advice() + if isinstance(advice, Advice): + advice = [advice] + + for item in advice: + self.publish(AdviceEvent(item)) + + def inject_context(self) -> None: + """Inject context object into components. + + Inject context object into components that supports context + injection. + """ + context_aware_components = ( + comp + for comp in itertools.chain( + self.collectors, + self.analyzers, + self.producers, + ) + if isinstance(comp, ContextMixin) + ) + + for component in context_aware_components: + component.set_context(self.context) + + def publish(self, event: Event) -> None: + """Publish event. + + Helper method for event publising. + + :param event: event instance + """ + self.context.event_publisher.publish_event(event) diff --git a/src/mlia/devices/__init__.py b/src/mlia/devices/__init__.py new file mode 100644 index 0000000..d533f4a --- /dev/null +++ b/src/mlia/devices/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Devices module.""" diff --git a/src/mlia/devices/config.py b/src/mlia/devices/config.py new file mode 100644 index 0000000..7ab6b43 --- /dev/null +++ b/src/mlia/devices/config.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""IP configuration module.""" + + +class IPConfiguration: # pylint: disable=too-few-public-methods + """Base class for IP configuration.""" + + def __init__(self, target: str) -> None: + """Init IP configuration instance.""" + self.target = target diff --git a/src/mlia/devices/ethosu/__init__.py b/src/mlia/devices/ethosu/__init__.py new file mode 100644 index 0000000..73925e1 --- /dev/null +++ b/src/mlia/devices/ethosu/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U devices module.""" diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py new file mode 100644 index 0000000..7a818c9 --- /dev/null +++ b/src/mlia/devices/ethosu/advice_generation.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U advice generation.""" +from functools import singledispatchmethod +from typing import List +from typing import Union + +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import advice_category +from mlia.core.advice_generation import ContextAwareAdviceProducer +from mlia.core.advice_generation import FactBasedAdviceProducer +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU +from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators +from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators +from mlia.devices.ethosu.data_analysis import OptimizationResults +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings + + +class EthosUAdviceProducer(FactBasedAdviceProducer): + """Ethos-U advice producer.""" + + @singledispatchmethod + def produce_advice(self, data_item: DataItem) -> None: + """Produce advice.""" + + @produce_advice.register + @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None: + """Advice for CPU only operators.""" + cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops))) + cpu_only_ops_num = len(data_item.cpu_only_ops) + + self.add_advice( + [ + f"You have at least {cpu_only_ops_num} " + f"operator{'s' if cpu_only_ops_num > 1 else ''} that is CPU " + f"only: {cpu_only_ops}.", + "Using operators that are supported by the NPU will " + "improve performance.", + ] + + self.context.action_resolver.supported_operators_info() + ) + + @produce_advice.register + @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + def handle_unsupported_operators( + self, data_item: HasUnsupportedOnNPUOperators + ) -> None: + """Advice for the unsupported operators.""" + self.add_advice( + [ + f"You have {data_item.npu_unsupported_ratio*100:.0f}% of operators " + "that cannot be placed on the NPU.", + "For better performance, please review the reasons reported " + "in the table, and adjust the model accordingly " + "where possible.", + ] + ) + + @produce_advice.register + @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + def handle_all_operators_supported( + self, _data_item: AllOperatorsSupportedOnNPU + ) -> None: + """Advice if all operators supported.""" + self.add_advice( + [ + "You don't have any unsupported operators, your model will " + "run completely on NPU." + ] + + self.context.action_resolver.check_performance() + ) + + @produce_advice.register + @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL) + def handle_optimization_results(self, data_item: OptimizationResults) -> None: + """Advice based on optimization results.""" + if not data_item.diffs or len(data_item.diffs) != 1: + return + + optim_details = data_item.diffs[0] + metrics = [ + (metric_name, optim_details.opt_diffs[metric_key]) + for (metric_name, metric_key) in ( + ("DRAM used (KB)", "dram"), + ("SRAM used (KB)", "sram"), + ("On chip flash used (KB)", "on_chip_flash"), + ("Off chip flash used (KB)", "off_chip_flash"), + ("NPU total cycles", "npu_total_cycles"), + ) + if metric_key in optim_details.opt_diffs + and not optim_details.opt_diffs[metric_key].same + ] + + improved = [ + f"- You have achieved {abs(metric_value.diff):.2f}% performance " + f"improvement in {metric_name}" + for metric_name, metric_value in metrics + if metric_value.improved + ] + + degraded = [ + f"- {metric_name} have degraded by {abs(metric_value.diff):.2f}%" + for metric_name, metric_value in metrics + if metric_value.degraded + ] + + opts = ", ".join(str(s) for s in optim_details.opt_type) + messages = [f"With the selected optimization ({opts})", *improved, *degraded] + + if improved: + if next_optimization_target := self.get_next_optimization_targets( + optim_details.opt_type + ): + next_optimization_target_as_str = " and/or ".join( + str(item) for item in next_optimization_target + ) + + messages.append( + "You can try to push the optimization target higher " + f"(e.g. {next_optimization_target_as_str}) " + "to check if those results can be further improved." + ) + messages += self.context.action_resolver.apply_optimizations( + opt_settings=next_optimization_target + ) + + elif degraded: + messages.append( + "The performance seems to have degraded after " + "applying the selected optimizations, " + "try exploring different optimization types/targets." + ) + + self.add_advice(messages) + + self.add_advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ) + + @staticmethod + def get_next_optimization_targets( + opt_type: List[OptimizationSettings], + ) -> List[OptimizationSettings]: + """Get next optimization targets.""" + next_targets = (item.next_target() for item in opt_type) + + # filter out targets that have not been changed + valid_targets = [ + next_ + for next_, old in zip(next_targets, opt_type) + if ( + old.optimization_type == "pruning" + and old.optimization_target < next_.optimization_target + ) + or ( + old.optimization_type == "clustering" + and old.optimization_target > next_.optimization_target + ) + ] + return valid_targets + + +class EthosUStaticAdviceProducer(ContextAwareAdviceProducer): + """Advice producer that not depends on input data.""" + + def produce_advice(self, data_item: DataItem) -> None: + """Do not process passed data items.""" + + def get_advice(self) -> Union[Advice, List[Advice]]: + """Return predefined advice based on category.""" + if self.context.advice_category is None: + return [] + + advice_per_category = { + AdviceCategory.PERFORMANCE: [ + Advice( + [ + "You can improve the inference time by using only operators " + "that are supported by the NPU.", + ] + + self.context.action_resolver.check_operator_compatibility() + ), + Advice( + [ + "Check if you can improve the performance by applying " + "tooling techniques to your model." + ] + + self.context.action_resolver.apply_optimizations() + ), + ], + AdviceCategory.OPTIMIZATION: [ + Advice( + [ + "For better performance, make sure that all the operators " + "of your final TFLite model are supported by the NPU.", + ] + + self.context.action_resolver.operator_compatibility_details() + ) + ], + } + + return advice_per_category.get(self.context.advice_category, []) diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py new file mode 100644 index 0000000..802826b --- /dev/null +++ b/src/mlia/devices/ethosu/advisor.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U MLIA module.""" +from pathlib import Path +from typing import List +from typing import Optional + +from mlia.core.advice_generation import AdviceProducer +from mlia.core.advisor import InferenceAdvisor +from mlia.core.common import AdviceCategory +from mlia.core.context import Context +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.mixins import ParameterResolverMixin +from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.core.workflow import WorkflowExecutor +from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer +from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.config import get_target +from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer +from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility +from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance +from mlia.devices.ethosu.data_collection import EthosUPerformance +from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent + + +class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin): + """Ethos-U Inference Advisor.""" + + @classmethod + def name(cls) -> str: + """Return name of the advisor.""" + return "ethos_u_inference_advisor" + + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor execution.""" + model = self._get_model(context) + device = self._get_device(context) + backends = self._get_backends(context) + + collectors = self._get_collectors(context, model, device, backends) + analyzers = self._get_analyzers() + producers = self._get_advice_producers() + + return DefaultWorkflowExecutor( + context, + collectors, + analyzers, + producers, + before_start_events=[ + EthosUAdvisorStartedEvent(device=device, model=model), + ], + ) + + def _get_collectors( + self, + context: Context, + model: Path, + device: EthosUConfiguration, + backends: Optional[List[str]], + ) -> List[DataCollector]: + """Get collectors.""" + collectors: List[DataCollector] = [] + + if context.any_category_enabled( + AdviceCategory.OPERATORS, + AdviceCategory.ALL, + ): + collectors.append(EthosUOperatorCompatibility(model, device)) + + if context.category_enabled(AdviceCategory.PERFORMANCE): + collectors.append(EthosUPerformance(model, device, backends)) + + if context.any_category_enabled( + AdviceCategory.OPTIMIZATION, + AdviceCategory.ALL, + ): + optimization_settings = self._get_optimization_settings(context) + collectors.append( + EthosUOptimizationPerformance( + model, device, optimization_settings, backends + ) + ) + + return collectors + + @staticmethod + def _get_analyzers() -> List[DataAnalyzer]: + """Return data analyzers.""" + return [ + EthosUDataAnalyzer(), + ] + + @staticmethod + def _get_advice_producers() -> List[AdviceProducer]: + """Return advice producers.""" + return [ + EthosUAdviceProducer(), + EthosUStaticAdviceProducer(), + ] + + def _get_device(self, context: Context) -> EthosUConfiguration: + """Get device.""" + device_params = self.get_parameter( + self.name(), + "device", + expected_type=dict, + context=context, + ) + + try: + target_profile = device_params["target_profile"] + except KeyError as err: + raise Exception("Unable to get device details") from err + + return get_target(target_profile) + + def _get_model(self, context: Context) -> Path: + """Get path to the model.""" + model_param = self.get_parameter( + self.name(), + "model", + expected_type=str, + context=context, + ) + + if not (model := Path(model_param)).exists(): + raise Exception(f"Path {model} does not exist") + + return model + + def _get_optimization_settings(self, context: Context) -> List[List[dict]]: + """Get optimization settings.""" + return self.get_parameter( # type: ignore + EthosUOptimizationPerformance.name(), + "optimizations", + expected_type=list, + expected=False, + context=context, + ) + + def _get_backends(self, context: Context) -> Optional[List[str]]: + """Get list of backends.""" + return self.get_parameter( # type: ignore + self.name(), + "backends", + expected_type=list, + expected=False, + context=context, + ) diff --git a/src/mlia/devices/ethosu/config.py b/src/mlia/devices/ethosu/config.py new file mode 100644 index 0000000..cecbb27 --- /dev/null +++ b/src/mlia/devices/ethosu/config.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U configuration.""" +import logging +from typing import Any +from typing import Dict + +from mlia.devices.config import IPConfiguration +from mlia.tools.vela_wrapper import resolve_compiler_config +from mlia.tools.vela_wrapper import VelaCompilerOptions +from mlia.utils.filesystem import get_profile +from mlia.utils.filesystem import get_vela_config + + +logger = logging.getLogger(__name__) + + +class EthosUConfiguration(IPConfiguration): + """Ethos-U configuration.""" + + def __init__(self, target_profile: str) -> None: + """Init Ethos-U target configuration.""" + target_data = get_profile(target_profile) + _check_target_data_complete(target_data) + + target = target_data["target"] + super().__init__(target) + + mac = target_data["mac"] + _check_device_options_valid(target, mac) + + self.mac = mac + self.compiler_options = VelaCompilerOptions( + system_config=target_data["system_config"], + memory_mode=target_data["memory_mode"], + config_files=str(get_vela_config()), + accelerator_config=f"{self.target}-{mac}", # type: ignore + ) + + @property + def resolved_compiler_config(self) -> Dict[str, Any]: + """Resolve compiler configuration.""" + return resolve_compiler_config(self.compiler_options) + + def __str__(self) -> str: + """Return string representation.""" + return ( + f"Ethos-U target={self.target} " + f"mac={self.mac} " + f"compiler_options={self.compiler_options}" + ) + + def __repr__(self) -> str: + """Return string representation.""" + return f"" + + +def get_target(target_profile: str) -> EthosUConfiguration: + """Get target instance based on provided params.""" + if not target_profile: + raise Exception("No target profile given") + + return EthosUConfiguration(target_profile) + + +def _check_target_data_complete(target_data: Dict[str, Any]) -> None: + """Check if profile contains all needed data.""" + mandatory_keys = {"target", "mac", "system_config", "memory_mode"} + missing_keys = sorted(mandatory_keys - target_data.keys()) + + if missing_keys: + raise Exception(f"Mandatory fields missing from target profile: {missing_keys}") + + +def _check_device_options_valid(target: str, mac: int) -> None: + """Check if mac is valid for selected device.""" + target_mac_ranges = { + "ethos-u55": [32, 64, 128, 256], + "ethos-u65": [256, 512], + } + + if target not in target_mac_ranges: + raise Exception(f"Unsupported target: {target}") + + target_mac_range = target_mac_ranges[target] + if mac not in target_mac_range: + raise Exception( + f"Mac value for selected device should be in {target_mac_range}" + ) diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py new file mode 100644 index 0000000..9ed32ff --- /dev/null +++ b/src/mlia/devices/ethosu/data_analysis.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U data analysis module.""" +from dataclasses import dataclass +from functools import singledispatchmethod +from typing import Dict +from typing import List +from typing import Union + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.core.data_analysis import FactExtractor +from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.tools.vela_wrapper import Operators + + +@dataclass +class HasCPUOnlyOperators(Fact): + """Model has CPU only operators.""" + + cpu_only_ops: List[str] + + +@dataclass +class HasUnsupportedOnNPUOperators(Fact): + """Model has unsupported on NPU operators.""" + + npu_unsupported_ratio: float + + +@dataclass +class AllOperatorsSupportedOnNPU(Fact): + """All model's operators supported on NPU.""" + + +@dataclass +class PerfMetricDiff: + """Performance metric difference.""" + + original_value: Union[int, float] + optimized_value: Union[int, float] + + @property + def diff(self) -> float: + """Difference between metrics.""" + if self.original_value == 0: + return 0 + + return 100 - ((self.optimized_value / self.original_value) * 100) + + @property + def improved(self) -> bool: + """Return true if metric improved.""" + return self.diff > 0 + + @property + def degraded(self) -> bool: + """Return true if metric degraded.""" + return self.diff < 0 + + @property + def same(self) -> bool: + """Return true if metric stays the same.""" + return self.diff == 0 + + +@dataclass +class OptimizationDiff: + """Optimization performance impact.""" + + opt_type: List[OptimizationSettings] + opt_diffs: Dict[str, PerfMetricDiff] + + +@dataclass +class OptimizationResults(Fact): + """Optimization results.""" + + diffs: List[OptimizationDiff] + + +class EthosUDataAnalyzer(FactExtractor): + """Ethos-U data analyzer.""" + + @singledispatchmethod + def analyze_data(self, data_item: DataItem) -> None: + """Analyse the data.""" + + @analyze_data.register + def analyze_operator_compatibility(self, operators: Operators) -> None: + """Analyse operator compatibility information.""" + cpu_only = [op.op_type for op in operators.ops if op.cpu_only] + if cpu_only: + self.add_fact(HasCPUOnlyOperators(cpu_only)) + + if operators.npu_unsupported_ratio != 0: + self.add_fact(HasUnsupportedOnNPUOperators(operators.npu_unsupported_ratio)) + + if operators.npu_unsupported_ratio == 0: + self.add_fact(AllOperatorsSupportedOnNPU()) + + @analyze_data.register + def analyze_optimization_results( + self, optimization_results: OptimizationPerformanceMetrics + ) -> None: + """Analyse optimization performance metrics.""" + optimizations = optimization_results.optimizations_perf_metrics + if not optimizations: + return + + orig = optimization_results.original_perf_metrics.in_kilobytes() + orig_memory = orig.memory_usage + orig_cycles = orig.npu_cycles + + diffs: List[OptimizationDiff] = [] + for opt_type, opt_perf_metrics in optimizations: + opt = opt_perf_metrics.in_kilobytes() + opt_memory = opt.memory_usage + opt_cycles = opt.npu_cycles + + opt_diffs: Dict[str, PerfMetricDiff] = {} + + if orig_memory and opt_memory: + opt_diffs.update( + { + "sram": PerfMetricDiff( + orig_memory.sram_memory_area_size, + opt_memory.sram_memory_area_size, + ), + "dram": PerfMetricDiff( + orig_memory.dram_memory_area_size, + opt_memory.dram_memory_area_size, + ), + "on_chip_flash": PerfMetricDiff( + orig_memory.on_chip_flash_memory_area_size, + opt_memory.on_chip_flash_memory_area_size, + ), + "off_chip_flash": PerfMetricDiff( + orig_memory.off_chip_flash_memory_area_size, + opt_memory.off_chip_flash_memory_area_size, + ), + } + ) + if orig_cycles and opt_cycles: + opt_diffs["npu_total_cycles"] = PerfMetricDiff( + orig_cycles.npu_total_cycles, + opt_cycles.npu_total_cycles, + ) + + diff = OptimizationDiff(opt_type=opt_type, opt_diffs=opt_diffs) + diffs.append(diff) + + self.add_fact(OptimizationResults(diffs)) diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py new file mode 100644 index 0000000..291f1b8 --- /dev/null +++ b/src/mlia/devices/ethosu/data_collection.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Data collection module for Ethos-U.""" +import logging +from pathlib import Path +from typing import List +from typing import Optional + +from mlia.core.context import Context +from mlia.core.data_collection import ContextAwareDataCollector +from mlia.core.errors import FunctionalityNotSupportedError +from mlia.core.performance import estimate_performance +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.performance import EthosUPerformanceEstimator +from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.nn.tensorflow.config import get_keras_model +from mlia.nn.tensorflow.config import get_tflite_model +from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.optimizations.select import get_optimizer +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.utils import save_keras_model +from mlia.tools.vela_wrapper import Operators +from mlia.tools.vela_wrapper import supported_operators +from mlia.utils.types import is_list_of + +logger = logging.getLogger(__name__) + + +class EthosUOperatorCompatibility(ContextAwareDataCollector): + """Collect operator compatibility information.""" + + def __init__(self, model: Path, device: EthosUConfiguration) -> None: + """Init operator compatibility data collector.""" + self.model = model + self.device = device + + def collect_data(self) -> Operators: + """Collect operator compatibility information.""" + tflite_model = get_tflite_model(self.model, self.context) + + logger.info("Checking operator compatibility ...") + ops = supported_operators( + Path(tflite_model.model_path), self.device.compiler_options + ) + logger.info("Done\n") + return ops + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "ethos_u_operator_compatibility" + + +class EthosUPerformance(ContextAwareDataCollector): + """Collect performance metrics.""" + + def __init__( + self, + model: Path, + device: EthosUConfiguration, + backends: Optional[List[str]] = None, + ) -> None: + """Init performance data collector.""" + self.model = model + self.device = device + self.backends = backends + + def collect_data(self) -> PerformanceMetrics: + """Collect model performance metrics.""" + tflite_model = get_tflite_model(self.model, self.context) + estimator = EthosUPerformanceEstimator( + self.context, + self.device, + self.backends, + ) + + return estimator.estimate(tflite_model) + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "ethos_u_performance" + + +class OptimizeModel: + """Helper class for model optimization.""" + + def __init__( + self, context: Context, opt_settings: List[OptimizationSettings] + ) -> None: + """Init helper.""" + self.context = context + self.opt_settings = opt_settings + + def __call__(self, keras_model: KerasModel) -> KerasModel: + """Run optimization.""" + optimizer = get_optimizer(keras_model, self.opt_settings) + + opts_as_str = ", ".join(str(opt) for opt in self.opt_settings) + logger.info("Applying model optimizations - [%s]", opts_as_str) + optimizer.apply_optimization() + + model = optimizer.get_model() + model_path = self.context.get_model_path("optimized_model.h5") + save_keras_model(model, model_path) + + return KerasModel(model_path) + + +class EthosUOptimizationPerformance(ContextAwareDataCollector): + """Collect performance metrics for the optimizations.""" + + def __init__( + self, + model: Path, + device: EthosUConfiguration, + optimizations: List[List[dict]], + backends: Optional[List[str]] = None, + ) -> None: + """Init performance optimizations data collector.""" + self.model = model + self.device = device + self.optimizations = optimizations + self.backends = backends + + def collect_data(self) -> Optional[OptimizationPerformanceMetrics]: + """Collect performance metrics for the optimizations.""" + logger.info("Estimate performance ...") + + if not self.optimizations: + raise FunctionalityNotSupportedError( + reason="Unable to estimate model optimizations impact", + description="No optimization targets provided", + ) + + opt_settings = self._parse_optimization_params(self.optimizations) + + try: + keras_model = get_keras_model(self.model, self.context) + except NotImplementedError as err: + raise FunctionalityNotSupportedError( + reason="Unable to run model optimizations", + description=f"{self.model} is not a Keras model and " + "could not be converted to a Keras model", + ) from err + + optimizers = [OptimizeModel(self.context, opts) for opts in opt_settings] + + estimator = EthosUPerformanceEstimator( + self.context, + self.device, + self.backends, + ) + original_metrics, *optimized_metrics = estimate_performance( + keras_model, estimator, optimizers # type: ignore + ) + + result = OptimizationPerformanceMetrics( + original_perf_metrics=original_metrics, + optimizations_perf_metrics=list(zip(opt_settings, optimized_metrics)), + ) + return result + + @staticmethod + def _parse_optimization_params( + optimizations: List[List[dict]], + ) -> List[List[OptimizationSettings]]: + """Parse optimization parameters.""" + if not is_list_of(optimizations, list): + raise Exception("Optimization parameters expected to be a list") + + return [ + [ + OptimizationSettings( + item.get("optimization_type"), # type: ignore + item.get("optimization_target"), # type: ignore + item.get("layers_to_optimized"), + ) + for item in opt_configuration + ] + for opt_configuration in optimizations + ] + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "ethos_u_model_optimizations" diff --git a/src/mlia/devices/ethosu/events.py b/src/mlia/devices/ethosu/events.py new file mode 100644 index 0000000..d5408b0 --- /dev/null +++ b/src/mlia/devices/ethosu/events.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Ethos-U MLIA module events.""" +from dataclasses import dataclass +from pathlib import Path + +from mlia.core.events import Event +from mlia.core.events import EventDispatcher +from mlia.devices.ethosu.config import EthosUConfiguration + + +@dataclass +class EthosUAdvisorStartedEvent(Event): + """Event with Ethos-U advisor parameters.""" + + model: Path + device: EthosUConfiguration + + +class EthosUAdvisorEventHandler(EventDispatcher): + """Event handler for the Ethos-U inference advisor.""" + + def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: + """Handle EthosUAdvisorStarted event.""" diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py new file mode 100644 index 0000000..7a0c31c --- /dev/null +++ b/src/mlia/devices/ethosu/handlers.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Event handler.""" +import logging +from pathlib import Path +from typing import Dict +from typing import List +from typing import Optional + +from mlia.core._typing import OutputFormat +from mlia.core._typing import PathOrFileLike +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import AdviceEvent +from mlia.core.events import AdviceStageFinishedEvent +from mlia.core.events import AdviceStageStartedEvent +from mlia.core.events import CollectedDataEvent +from mlia.core.events import DataAnalysisStageFinishedEvent +from mlia.core.events import DataCollectionStageStartedEvent +from mlia.core.events import DataCollectorSkippedEvent +from mlia.core.events import ExecutionFailedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.events import SystemEventsHandler +from mlia.core.reporting import Reporter +from mlia.devices.ethosu.events import EthosUAdvisorEventHandler +from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent +from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.tools.vela_wrapper import Operators +from mlia.utils.console import create_section_header + +logger = logging.getLogger(__name__) + +ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") +MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") +MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") +ADV_GENERATION_MSG = create_section_header("Advice Generation") +REPORT_GENERATION_MSG = create_section_header("Report Generation") + + +class WorkflowEventsHandler(SystemEventsHandler): + """Event handler for the system events.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + logger.info(ADV_EXECUTION_STARTED) + + def on_execution_failed(self, event: ExecutionFailedEvent) -> None: + """Handle ExecutionFailed event.""" + raise event.err + + def on_data_collection_stage_started( + self, event: DataCollectionStageStartedEvent + ) -> None: + """Handle DataCollectionStageStarted event.""" + logger.info(MODEL_ANALYSIS_MSG) + + def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None: + """Handle AdviceStageStarted event.""" + logger.info(ADV_GENERATION_MSG) + + def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None: + """Handle DataCollectorSkipped event.""" + logger.info("Skipped: %s", event.reason) + + +class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): + """CLI event handler.""" + + def __init__(self, output: Optional[PathOrFileLike] = None) -> None: + """Init event handler.""" + output_format = self.resolve_output_format(output) + + self.reporter = Reporter(find_appropriate_formatter, output_format) + self.output = output + self.advice: List[Advice] = [] + + def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None: + """Handle AdviceStageFinishedEvent event.""" + self.reporter.submit( + self.advice, + show_title=False, + show_headers=False, + space="between", + table_style="no_borders", + ) + + self.reporter.generate_report(self.output) + + if self.output is not None: + logger.info(REPORT_GENERATION_MSG) + logger.info("Report(s) and advice list saved to: %s", self.output) + + def on_data_analysis_stage_finished( + self, event: DataAnalysisStageFinishedEvent + ) -> None: + """Handle DataAnalysisStageFinished event.""" + logger.info(MODEL_ANALYSIS_RESULTS_MSG) + self.reporter.print_delayed() + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedDataEvent event.""" + data_item = event.data_item + + if isinstance(data_item, Operators): + self.reporter.submit([data_item.ops, data_item], delay_print=True) + + if isinstance(data_item, PerformanceMetrics): + self.reporter.submit(data_item, delay_print=True) + + if isinstance(data_item, OptimizationPerformanceMetrics): + original_metrics = data_item.original_perf_metrics + if not data_item.optimizations_perf_metrics: + return + + _opt_settings, optimized_metrics = data_item.optimizations_perf_metrics[0] + + self.reporter.submit( + [original_metrics, optimized_metrics], + delay_print=True, + columns_name="Metrics", + title="Performance metrics", + space=True, + ) + + def on_advice_event(self, event: AdviceEvent) -> None: + """Handle Advice event.""" + self.advice.append(event.advice) + + def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: + """Handle EthosUAdvisorStarted event.""" + self.reporter.submit(event.device) + + @staticmethod + def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat: + """Resolve output format based on the output name.""" + output_format: OutputFormat = "plain_text" + + if isinstance(output, str): + output_path = Path(output) + output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"} + + if (suffix := output_path.suffix) in output_formats: + return output_formats[suffix] + + return output_format diff --git a/src/mlia/devices/ethosu/operators.py b/src/mlia/devices/ethosu/operators.py new file mode 100644 index 0000000..ff0d99f --- /dev/null +++ b/src/mlia/devices/ethosu/operators.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Operators module.""" +import logging + +from mlia.tools import vela_wrapper + + +logger = logging.getLogger(__name__) + + +def generate_supported_operators_report() -> None: + """Generate supported operators report.""" + vela_wrapper.generate_supported_operators_report() diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py new file mode 100644 index 0000000..b0718a5 --- /dev/null +++ b/src/mlia/devices/ethosu/performance.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Performance estimation.""" +import logging +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import mlia.tools.aiet_wrapper as aiet +import mlia.tools.vela_wrapper as vela +from mlia.core.context import Context +from mlia.core.performance import PerformanceEstimator +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.nn.tensorflow.config import get_tflite_model +from mlia.nn.tensorflow.config import ModelConfiguration +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings + + +logger = logging.getLogger(__name__) + + +@dataclass +class NPUCycles: + """NPU cycles metrics.""" + + npu_active_cycles: int + npu_idle_cycles: int + npu_total_cycles: int + npu_axi0_rd_data_beat_received: int + npu_axi0_wr_data_beat_written: int + npu_axi1_rd_data_beat_received: int + + +BYTES_PER_KILOBYTE = 1024 + + +class MemorySizeType(Enum): + """Memory size type enumeration.""" + + BYTES = 0 + KILOBYTES = 1 + + +@dataclass +class MemoryUsage: + """Memory usage metrics.""" + + sram_memory_area_size: Union[int, float] + dram_memory_area_size: Union[int, float] + unknown_memory_area_size: Union[int, float] + on_chip_flash_memory_area_size: Union[int, float] + off_chip_flash_memory_area_size: Union[int, float] + memory_size_type: MemorySizeType = MemorySizeType.BYTES + + _default_columns = [ + "SRAM used", + "DRAM used", + "Unknown memory used", + "On chip flash used", + "Off chip flash used", + ] + + def in_kilobytes(self) -> "MemoryUsage": + """Return memory usage with values in kilobytes.""" + if self.memory_size_type == MemorySizeType.KILOBYTES: + return self + + kilobytes = [ + value / BYTES_PER_KILOBYTE + for value in [ + self.sram_memory_area_size, + self.dram_memory_area_size, + self.unknown_memory_area_size, + self.on_chip_flash_memory_area_size, + self.off_chip_flash_memory_area_size, + ] + ] + + return MemoryUsage( + *kilobytes, # type: ignore + memory_size_type=MemorySizeType.KILOBYTES, + ) + + +@dataclass +class PerformanceMetrics: + """Performance metrics.""" + + device: EthosUConfiguration + npu_cycles: Optional[NPUCycles] + memory_usage: Optional[MemoryUsage] + + def in_kilobytes(self) -> "PerformanceMetrics": + """Return metrics with memory usage in KiB.""" + if self.memory_usage is None: + return PerformanceMetrics(self.device, self.npu_cycles, self.memory_usage) + + return PerformanceMetrics( + self.device, self.npu_cycles, self.memory_usage.in_kilobytes() + ) + + +@dataclass +class OptimizationPerformanceMetrics: + """Optimization performance metrics.""" + + original_perf_metrics: PerformanceMetrics + optimizations_perf_metrics: List[ + Tuple[List[OptimizationSettings], PerformanceMetrics] + ] + + +class VelaPerformanceEstimator( + PerformanceEstimator[Union[Path, ModelConfiguration], MemoryUsage] +): + """Vela based performance estimator.""" + + def __init__(self, context: Context, device: EthosUConfiguration) -> None: + """Init Vela based performance estimator.""" + self.context = context + self.device = device + + def estimate(self, model: Union[Path, ModelConfiguration]) -> MemoryUsage: + """Estimate performance.""" + logger.info("Getting the memory usage metrics ...") + + model_path = ( + Path(model.model_path) if isinstance(model, ModelConfiguration) else model + ) + + vela_perf_metrics = vela.estimate_performance( + model_path, self.device.compiler_options + ) + + memory_usage = MemoryUsage( + vela_perf_metrics.sram_memory_area_size, + vela_perf_metrics.dram_memory_area_size, + vela_perf_metrics.unknown_memory_area_size, + vela_perf_metrics.on_chip_flash_memory_area_size, + vela_perf_metrics.off_chip_flash_memory_area_size, + ) + logger.info("Done\n") + return memory_usage + + +class AIETPerformanceEstimator( + PerformanceEstimator[Union[Path, ModelConfiguration], NPUCycles] +): + """AIET based performance estimator.""" + + def __init__( + self, context: Context, device: EthosUConfiguration, backend: str + ) -> None: + """Init AIET based performance estimator.""" + self.context = context + self.device = device + self.backend = backend + + def estimate(self, model: Union[Path, ModelConfiguration]) -> NPUCycles: + """Estimate performance.""" + logger.info("Getting the performance metrics for '%s' ...", self.backend) + logger.info( + "WARNING: This task may require several minutes (press ctrl-c to interrupt)" + ) + + model_path = ( + Path(model.model_path) if isinstance(model, ModelConfiguration) else model + ) + + optimized_model_path = self.context.get_model_path( + f"{model_path.stem}_vela.tflite" + ) + + vela.optimize_model( + model_path, self.device.compiler_options, optimized_model_path + ) + + model_info = aiet.ModelInfo(model_path=optimized_model_path) + device_info = aiet.DeviceInfo( + device_type=self.device.target, # type: ignore + mac=self.device.mac, + memory_mode=self.device.compiler_options.memory_mode, # type: ignore + ) + + aiet_perf_metrics = aiet.estimate_performance( + model_info, device_info, self.backend + ) + + npu_cycles = NPUCycles( + aiet_perf_metrics.npu_active_cycles, + aiet_perf_metrics.npu_idle_cycles, + aiet_perf_metrics.npu_total_cycles, + aiet_perf_metrics.npu_axi0_rd_data_beat_received, + aiet_perf_metrics.npu_axi0_wr_data_beat_written, + aiet_perf_metrics.npu_axi1_rd_data_beat_received, + ) + + logger.info("Done\n") + return npu_cycles + + +class EthosUPerformanceEstimator( + PerformanceEstimator[Union[Path, ModelConfiguration], PerformanceMetrics] +): + """Ethos-U performance estimator.""" + + def __init__( + self, + context: Context, + device: EthosUConfiguration, + backends: Optional[List[str]] = None, + ) -> None: + """Init performance estimator.""" + self.context = context + self.device = device + if backends is None: + backends = ["Vela"] # Only Vela is always available as default + for backend in backends: + if backend != "Vela" and not aiet.is_supported(backend): + raise ValueError( + f"Unsupported backend '{backend}'. " + f"Only 'Vela' and {aiet.supported_backends()} are supported." + ) + self.backends = set(backends) + + def estimate(self, model: Union[Path, ModelConfiguration]) -> PerformanceMetrics: + """Estimate performance.""" + model_path = ( + Path(model.model_path) if isinstance(model, ModelConfiguration) else model + ) + + tflite_model = get_tflite_model(model_path, self.context) + + memory_usage = None + npu_cycles = None + + for backend in self.backends: + if backend == "Vela": + vela_estimator = VelaPerformanceEstimator(self.context, self.device) + memory_usage = vela_estimator.estimate(tflite_model) + elif backend in aiet.supported_backends(): + aiet_estimator = AIETPerformanceEstimator( + self.context, self.device, backend + ) + npu_cycles = aiet_estimator.estimate(tflite_model) + else: + logger.warning( + "Backend '%s' is not supported for Ethos-U performance " + "estimation.", + backend, + ) + + return PerformanceMetrics(self.device, npu_cycles, memory_usage) diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py new file mode 100644 index 0000000..d28c68f --- /dev/null +++ b/src/mlia/devices/ethosu/reporters.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Reports module.""" +from collections import defaultdict +from typing import Any +from typing import Callable +from typing import List +from typing import Tuple +from typing import Union + +from mlia.core.advice_generation import Advice +from mlia.core.reporting import BytesCell +from mlia.core.reporting import Cell +from mlia.core.reporting import ClockCell +from mlia.core.reporting import Column +from mlia.core.reporting import CompoundFormatter +from mlia.core.reporting import CyclesCell +from mlia.core.reporting import Format +from mlia.core.reporting import NestedReport +from mlia.core.reporting import Report +from mlia.core.reporting import ReportItem +from mlia.core.reporting import SingleRow +from mlia.core.reporting import Table +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.tools.vela_wrapper import Operator +from mlia.tools.vela_wrapper import Operators +from mlia.utils.console import style_improvement +from mlia.utils.types import is_list_of + + +def report_operators_stat(operators: Operators) -> Report: + """Return table representation for the ops stats.""" + columns = [ + Column("Number of operators", alias="num_of_operators"), + Column("Number of NPU supported operators", "num_of_npu_supported_operators"), + Column("Unsupported ops ratio", "npu_unsupported_ratio"), + ] + rows = [ + ( + operators.total_number, + operators.npu_supported_number, + Cell( + operators.npu_unsupported_ratio * 100, + fmt=Format(str_fmt="{0:.0f}%".format), + ), + ) + ] + + return SingleRow( + columns, rows, name="Operators statistics", alias="operators_stats" + ) + + +def report_operators(ops: List[Operator]) -> Report: + """Return table representation for the list of operators.""" + columns = [ + Column("#", only_for=["plain_text"]), + Column( + "Operator name", + alias="operator_name", + fmt=Format(wrap_width=30), + ), + Column( + "Operator type", + alias="operator_type", + fmt=Format(wrap_width=25), + ), + Column( + "Placement", + alias="placement", + fmt=Format(wrap_width=20), + ), + Column( + "Notes", + alias="notes", + fmt=Format(wrap_width=35), + ), + ] + + rows = [ + ( + i + 1, + op.name, + op.op_type, + Cell( + "NPU" if (npu := op.run_on_npu.supported) else "CPU", + Format(style=style_improvement(npu)), + ), + Table( + columns=[ + Column( + "Note", + alias="note", + fmt=Format(wrap_width=35), + ) + ], + rows=[ + (Cell(item, Format(str_fmt=lambda x: f"* {x}")),) + for reason in op.run_on_npu.reasons + for item in reason + if item + ], + name="Notes", + ), + ) + for i, op in enumerate(ops) + ] + + return Table(columns, rows, name="Operators", alias="operators") + + +def report_device_details(device: EthosUConfiguration) -> Report: + """Return table representation for the device.""" + compiler_config = device.resolved_compiler_config + + memory_settings = [ + ReportItem( + "Const mem area", + "const_mem_area", + compiler_config["const_mem_area"], + ), + ReportItem( + "Arena mem area", + "arena_mem_area", + compiler_config["arena_mem_area"], + ), + ReportItem( + "Cache mem area", + "cache_mem_area", + compiler_config["cache_mem_area"], + ), + ReportItem( + "Arena cache size", + "arena_cache_size", + BytesCell(compiler_config["arena_cache_size"]), + ), + ] + + mem_areas_settings = [ + ReportItem( + f"{mem_area_name}", + mem_area_name, + None, + nested_items=[ + ReportItem( + "Clock scales", + "clock_scales", + mem_area_settings["clock_scales"], + ), + ReportItem( + "Burst length", + "burst_length", + BytesCell(mem_area_settings["burst_length"]), + ), + ReportItem( + "Read latency", + "read_latency", + CyclesCell(mem_area_settings["read_latency"]), + ), + ReportItem( + "Write latency", + "write_latency", + CyclesCell(mem_area_settings["write_latency"]), + ), + ], + ) + for mem_area_name, mem_area_settings in compiler_config["memory_area"].items() + ] + + system_settings = [ + ReportItem( + "Accelerator clock", + "accelerator_clock", + ClockCell(compiler_config["core_clock"]), + ), + ReportItem( + "AXI0 port", + "axi0_port", + compiler_config["axi0_port"], + ), + ReportItem( + "AXI1 port", + "axi1_port", + compiler_config["axi1_port"], + ), + ReportItem( + "Memory area settings", "memory_area", None, nested_items=mem_areas_settings + ), + ] + + arch_settings = [ + ReportItem( + "Permanent storage mem area", + "permanent_storage_mem_area", + compiler_config["permanent_storage_mem_area"], + ), + ReportItem( + "Feature map storage mem area", + "feature_map_storage_mem_area", + compiler_config["feature_map_storage_mem_area"], + ), + ReportItem( + "Fast storage mem area", + "fast_storage_mem_area", + compiler_config["fast_storage_mem_area"], + ), + ] + + return NestedReport( + "Device information", + "device", + [ + ReportItem("Target", alias="target", value=device.target), + ReportItem("MAC", alias="mac", value=device.mac), + ReportItem( + "Memory mode", + alias="memory_mode", + value=compiler_config["memory_mode"], + nested_items=memory_settings, + ), + ReportItem( + "System config", + alias="system_config", + value=compiler_config["system_config"], + nested_items=system_settings, + ), + ReportItem( + "Architecture settings", + "arch_settings", + None, + nested_items=arch_settings, + ), + ], + ) + + +def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + """Convert perf metrics object into list of records.""" + perf_metrics = [item.in_kilobytes() for item in perf_metrics] + + def _cycles_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + metric_map = defaultdict(list) + for metrics in perf_metrics: + if not metrics.npu_cycles: + return [] + metric_map["NPU active cycles"].append(metrics.npu_cycles.npu_active_cycles) + metric_map["NPU idle cycles"].append(metrics.npu_cycles.npu_idle_cycles) + metric_map["NPU total cycles"].append(metrics.npu_cycles.npu_total_cycles) + + return [ + (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "cycles") + for name, values in metric_map.items() + ] + + def _memory_usage_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + metric_map = defaultdict(list) + for metrics in perf_metrics: + if not metrics.memory_usage: + return [] + metric_map["SRAM used"].append(metrics.memory_usage.sram_memory_area_size) + metric_map["DRAM used"].append(metrics.memory_usage.dram_memory_area_size) + metric_map["Unknown memory area used"].append( + metrics.memory_usage.unknown_memory_area_size + ) + metric_map["On-chip flash used"].append( + metrics.memory_usage.on_chip_flash_memory_area_size + ) + metric_map["Off-chip flash used"].append( + metrics.memory_usage.off_chip_flash_memory_area_size + ) + + return [ + (name, *(Cell(value, Format(str_fmt="12.2f")) for value in values), "KiB") + for name, values in metric_map.items() + if all(val > 0 for val in values) + ] + + def _data_beats_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]: + metric_map = defaultdict(list) + for metrics in perf_metrics: + if not metrics.npu_cycles: + return [] + metric_map["NPU AXI0 RD data beat received"].append( + metrics.npu_cycles.npu_axi0_rd_data_beat_received + ) + metric_map["NPU AXI0 WR data beat written"].append( + metrics.npu_cycles.npu_axi0_wr_data_beat_written + ) + metric_map["NPU AXI1 RD data beat received"].append( + metrics.npu_cycles.npu_axi1_rd_data_beat_received + ) + + return [ + (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "beats") + for name, values in metric_map.items() + ] + + return [ + metrics + for metrics_func in ( + _memory_usage_as_records, + _cycles_as_records, + _data_beats_as_records, + ) + for metrics in metrics_func(perf_metrics) + ] + + +def report_perf_metrics( + perf_metrics: Union[PerformanceMetrics, List[PerformanceMetrics]] +) -> Report: + """Return comparison table for the performance metrics.""" + if isinstance(perf_metrics, PerformanceMetrics): + perf_metrics = [perf_metrics] + + rows = metrics_as_records(perf_metrics) + + if len(perf_metrics) == 2: + return Table( + columns=[ + Column("Metric", alias="metric", fmt=Format(wrap_width=30)), + Column("Original", alias="original", fmt=Format(wrap_width=15)), + Column("Optimized", alias="optimized", fmt=Format(wrap_width=15)), + Column("Unit", alias="unit", fmt=Format(wrap_width=15)), + Column("Improvement (%)", alias="improvement"), + ], + rows=[ + ( + metric, + original_value, + optimized_value, + unit, + Cell( + ( + diff := 100 + - (optimized_value.value / original_value.value * 100) + ), + Format(str_fmt="15.2f", style=style_improvement(diff > 0)), + ) + if original_value.value != 0 + else None, + ) + for metric, original_value, optimized_value, unit in rows + ], + name="Performance metrics", + alias="performance_metrics", + notes="IMPORTANT: The performance figures above refer to NPU only", + ) + + return Table( + columns=[ + Column("Metric", alias="metric", fmt=Format(wrap_width=30)), + Column("Value", alias="value", fmt=Format(wrap_width=15)), + Column("Unit", alias="unit", fmt=Format(wrap_width=15)), + ], + rows=rows, + name="Performance metrics", + alias="performance_metrics", + notes="IMPORTANT: The performance figures above refer to NPU only", + ) + + +def report_advice(advice: List[Advice]) -> Report: + """Generate report for the advice.""" + return Table( + columns=[ + Column("#", only_for=["plain_text"]), + Column("Advice", alias="advice_message"), + ], + rows=[(i + 1, a.messages) for i, a in enumerate(advice)], + name="Advice", + alias="advice", + ) + + +def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]: + """Find appropriate formatter for the provided data.""" + if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2): + return report_perf_metrics + + if is_list_of(data, Advice): + return report_advice + + if is_list_of(data, Operator): + return report_operators + + if isinstance(data, Operators): + return report_operators_stat + + if isinstance(data, EthosUConfiguration): + return report_device_details + + if isinstance(data, (list, tuple)): + formatters = [find_appropriate_formatter(item) for item in data] + return CompoundFormatter(formatters) + + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/nn/__init__.py b/src/mlia/nn/__init__.py new file mode 100644 index 0000000..aac2830 --- /dev/null +++ b/src/mlia/nn/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""NN related module.""" diff --git a/src/mlia/nn/tensorflow/__init__.py b/src/mlia/nn/tensorflow/__init__.py new file mode 100644 index 0000000..ff061c1 --- /dev/null +++ b/src/mlia/nn/tensorflow/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TensorFlow related module.""" diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py new file mode 100644 index 0000000..d3235d7 --- /dev/null +++ b/src/mlia/nn/tensorflow/config.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Model configuration.""" +import logging +from pathlib import Path +from typing import cast +from typing import Dict +from typing import List +from typing import Union + +import tensorflow as tf + +from mlia.core.context import Context +from mlia.nn.tensorflow.utils import convert_tf_to_tflite +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import is_keras_model +from mlia.nn.tensorflow.utils import is_tf_model +from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.nn.tensorflow.utils import save_tflite_model + +logger = logging.getLogger(__name__) + + +class ModelConfiguration: + """Base class for model configuration.""" + + def __init__(self, model_path: Union[str, Path]) -> None: + """Init model configuration instance.""" + self.model_path = str(model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + raise NotImplementedError() + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + raise NotImplementedError() + + +class KerasModel(ModelConfiguration): + """Keras model configuration. + + Supports all models supported by Keras API: saved model, H5, HDF5 + """ + + def get_keras_model(self) -> tf.keras.Model: + """Return associated Keras model.""" + return tf.keras.models.load_model(self.model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + logger.info("Converting Keras to TFLite ...") + + converted_model = convert_to_tflite(self.get_keras_model(), quantized) + logger.info("Done\n") + + save_tflite_model(converted_model, tflite_model_path) + logger.debug( + "Model %s converted and saved to %s", self.model_path, tflite_model_path + ) + + return TFLiteModel(tflite_model_path) + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + return self + + +class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method + """TFLite model configuration.""" + + def input_details(self) -> List[Dict]: + """Get model's input details.""" + interpreter = tf.lite.Interpreter(model_path=self.model_path) + return cast(List[Dict], interpreter.get_input_details()) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + return self + + +class TfModel(ModelConfiguration): # pylint: disable=abstract-method + """TensorFlow model configuration. + + Supports models supported by TensorFlow API (not Keras) + """ + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + converted_model = convert_tf_to_tflite(self.model_path, quantized) + save_tflite_model(converted_model, tflite_model_path) + + return TFLiteModel(tflite_model_path) + + +def get_model(model: Union[Path, str]) -> "ModelConfiguration": + """Return the model object.""" + if is_tflite_model(model): + return TFLiteModel(model) + + if is_keras_model(model): + return KerasModel(model) + + if is_tf_model(model): + return TfModel(model) + + raise Exception( + "The input model format is not supported" + "(supported formats: TFLite, Keras, TensorFlow saved model)!" + ) + + +def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": + """Convert input model to TFLite and returns TFLiteModel object.""" + tflite_model_path = ctx.get_model_path("converted_model.tflite") + converted_model = get_model(model) + + return converted_model.convert_to_tflite(tflite_model_path, True) + + +def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel": + """Convert input model to Keras and returns KerasModel object.""" + keras_model_path = ctx.get_model_path("converted_model.h5") + converted_model = get_model(model) + + return converted_model.convert_to_keras(keras_model_path) diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py new file mode 100644 index 0000000..201c130 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Optimizations module.""" diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py new file mode 100644 index 0000000..16d9e4b --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/clustering.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Clusterer that clusters unique weights per layer to a specified number. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to cluster. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. +""" +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module + cluster as experimental_cluster, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class ClusteringConfiguration(OptimizerConfiguration): + """Clustering configuration.""" + + optimization_target: int + layers_to_optimize: Optional[List[str]] = None + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"clustering: {self.optimization_target}" + + +class Clusterer(Optimizer): + """ + Clusterer class. + + Used to cluster a model to a specified number of unique weights per layer. + + Sample usage: + clusterer = Clusterer( + base_model, + optimizer_configuration) + + clusterer.apply_clustering() + clustered_model = clusterer.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration + ): + """Init Clusterer instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _setup_clustering_params(self) -> Dict[str, Any]: + CentroidInitialization = tfmot.clustering.keras.CentroidInitialization + return { + "number_of_clusters": self.optimizer_configuration.optimization_target, + "cluster_centroids_init": CentroidInitialization.LINEAR, + "preserve_sparsity": True, + } + + def _apply_clustering_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + clustering_params = self._setup_clustering_params() + return experimental_cluster.cluster_weights(layer, **clustering_params) + + def _init_for_clustering(self) -> None: + # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer` + # to the layers of the model + if not self.optimizer_configuration.layers_to_optimize: + clustering_params = self._setup_clustering_params() + clustered_model = experimental_cluster.cluster_weights( + self.model, **clustering_params + ) + else: + clustered_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_clustering_to_layer + ) + + self.model = clustered_model + + def _strip_clustering(self) -> None: + self.model = tfmot.clustering.keras.strip_clustering(self.model) + + def apply_optimization(self) -> None: + """Apply all steps of clustering at once.""" + self._init_for_clustering() + self._strip_clustering() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py new file mode 100644 index 0000000..1dce0b2 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/common.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common items for the optimizations module.""" +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass + +import tensorflow as tf + + +@dataclass +class OptimizerConfiguration: + """Abstract optimizer configuration.""" + + +class Optimizer(ABC): + """Abstract class for the optimizer.""" + + @abstractmethod + def get_model(self) -> tf.keras.Model: + """Abstract method to return the model instance from the optimizer.""" + + @abstractmethod + def apply_optimization(self) -> None: + """Abstract method to apply optimization to the model.""" + + @abstractmethod + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py new file mode 100644 index 0000000..f629ba1 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/pruning.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class Pruner to prune a model to a specified sparsity. + +In order to do this, we need to have a base model and corresponding training data. +We also have to specify a subset of layers we want to prune. For more details, +please refer to the documentation for TensorFlow Model Optimization Toolkit. +""" +from dataclasses import dataclass +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_model_optimization as tfmot +from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module + pruning_wrapper, +) + +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration + + +@dataclass +class PruningConfiguration(OptimizerConfiguration): + """Pruning configuration.""" + + optimization_target: float + layers_to_optimize: Optional[List[str]] = None + x_train: Optional[np.array] = None + y_train: Optional[np.array] = None + batch_size: int = 1 + num_epochs: int = 1 + + def __str__(self) -> str: + """Return string representation of the configuration.""" + return f"pruning: {self.optimization_target}" + + def has_training_data(self) -> bool: + """Return True if training data provided.""" + return self.x_train is not None and self.y_train is not None + + +class Pruner(Optimizer): + """ + Pruner class. Used to prune a model to a specified sparsity. + + Sample usage: + pruner = Pruner( + base_model, + optimizer_configuration) + + pruner.apply_pruning() + pruned_model = pruner.get_model() + """ + + def __init__( + self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration + ): + """Init Pruner instance.""" + self.model = model + self.optimizer_configuration = optimizer_configuration + + if not optimizer_configuration.has_training_data(): + mock_x_train, mock_y_train = self._mock_train_data() + + self.optimizer_configuration.x_train = mock_x_train + self.optimizer_configuration.y_train = mock_y_train + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return str(self.optimizer_configuration) + + def _mock_train_data(self) -> Tuple[np.array, np.array]: + # get rid of the batch_size dimension in input and output shape + input_shape = tuple(x for x in self.model.input_shape if x is not None) + output_shape = tuple(x for x in self.model.output_shape if x is not None) + + return ( + np.random.rand(*input_shape), + np.random.randint(0, output_shape[-1], (output_shape[:-1])), + ) + + def _setup_pruning_params(self) -> dict: + return { + "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=0, + final_sparsity=self.optimizer_configuration.optimization_target, + begin_step=0, + end_step=self.optimizer_configuration.num_epochs, + frequency=1, + ), + } + + def _apply_pruning_to_layer( + self, layer: tf.keras.layers.Layer + ) -> tf.keras.layers.Layer: + layers_to_optimize = self.optimizer_configuration.layers_to_optimize + assert layers_to_optimize, "List of the layers to optimize is empty" + + if layer.name not in layers_to_optimize: + return layer + + pruning_params = self._setup_pruning_params() + return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params) + + def _init_for_pruning(self) -> None: + # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer` + # to the layers of the model + if not self.optimizer_configuration.layers_to_optimize: + pruning_params = self._setup_pruning_params() + prunable_model = tfmot.sparsity.keras.prune_low_magnitude( + self.model, **pruning_params + ) + else: + prunable_model = tf.keras.models.clone_model( + self.model, clone_function=self._apply_pruning_to_layer + ) + + self.model = prunable_model + + def _train_pruning(self) -> None: + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() + self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) + + # Model callbacks + callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] + + # Fitting data + self.model.fit( + self.optimizer_configuration.x_train, + self.optimizer_configuration.y_train, + batch_size=self.optimizer_configuration.batch_size, + epochs=self.optimizer_configuration.num_epochs, + callbacks=callbacks, + verbose=0, + ) + + def _assert_sparsity_reached(self) -> None: + for layer in self.model.layers: + if not isinstance(layer, pruning_wrapper.PruneLowMagnitude): + continue + + for weight in layer.layer.get_prunable_weights(): + nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight)) + all_weights = tf.keras.backend.get_value(weight).size + + np.testing.assert_approx_equal( + self.optimizer_configuration.optimization_target, + 1 - nonzero_weights / all_weights, + significant=2, + ) + + def _strip_pruning(self) -> None: + self.model = tfmot.sparsity.keras.strip_pruning(self.model) + + def apply_optimization(self) -> None: + """Apply all steps of pruning sequentially.""" + self._init_for_pruning() + self._train_pruning() + self._assert_sparsity_reached() + self._strip_pruning() + + def get_model(self) -> tf.keras.Model: + """Get model.""" + return self.model diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py new file mode 100644 index 0000000..1b0c755 --- /dev/null +++ b/src/mlia/nn/tensorflow/optimizations/select.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for optimization selection.""" +import math +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple +from typing import Union + +import tensorflow as tf + +from mlia.core.errors import ConfigurationError +from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.common import Optimizer +from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.utils.types import is_list_of + + +class OptimizationSettings(NamedTuple): + """Optimization settings.""" + + optimization_type: str + optimization_target: Union[int, float] + layers_to_optimize: Optional[List[str]] + + @staticmethod + def create_from( + optimizer_params: List[Tuple[str, float]], + layers_to_optimize: Optional[List[str]] = None, + ) -> List["OptimizationSettings"]: + """Create optimization settings from the provided parameters.""" + return [ + OptimizationSettings( + optimization_type=opt_type, + optimization_target=opt_target, + layers_to_optimize=layers_to_optimize, + ) + for opt_type, opt_target in optimizer_params + ] + + def __str__(self) -> str: + """Return string representation.""" + return f"{self.optimization_type}: {self.optimization_target}" + + def next_target(self) -> "OptimizationSettings": + """Return next optimization target.""" + if self.optimization_type == "pruning": + next_target = round(min(self.optimization_target + 0.1, 0.9), 2) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + if self.optimization_type == "clustering": + # return next lowest power of two for clustering + next_target = math.log(self.optimization_target, 2) + if next_target.is_integer(): + next_target -= 1 + + next_target = max(int(2 ** int(next_target)), 4) + return OptimizationSettings( + self.optimization_type, next_target, self.layers_to_optimize + ) + + raise Exception(f"Unknown optimization type {self.optimization_type}") + + +class MultiStageOptimizer(Optimizer): + """Optimizer with multiply stages.""" + + def __init__( + self, + model: tf.keras.Model, + optimizations: List[OptimizerConfiguration], + ) -> None: + """Init MultiStageOptimizer instance.""" + self.model = model + self.optimizations = optimizations + + def optimization_config(self) -> str: + """Return string representation of the optimization config.""" + return " - ".join(str(opt) for opt in self.optimizations) + + def get_model(self) -> tf.keras.Model: + """Return optimized model.""" + return self.model + + def apply_optimization(self) -> None: + """Apply optimization to the model.""" + for config in self.optimizations: + optimizer = get_optimizer(self.model, config) + optimizer.apply_optimization() + self.model = optimizer.get_model() + + +def get_optimizer( + model: Union[tf.keras.Model, KerasModel], + config: Union[ + OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings] + ], +) -> Optimizer: + """Get optimizer for provided configuration.""" + if isinstance(model, KerasModel): + model = model.get_keras_model() + + if isinstance(config, PruningConfiguration): + return Pruner(model, config) + + if isinstance(config, ClusteringConfiguration): + return Clusterer(model, config) + + if isinstance(config, OptimizationSettings) or is_list_of( + config, OptimizationSettings + ): + return _get_optimizer(model, config) # type: ignore + + raise ConfigurationError(f"Unknown optimization configuration {config}") + + +def _get_optimizer( + model: tf.keras.Model, + optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]], +) -> Optimizer: + if isinstance(optimization_settings, OptimizationSettings): + optimization_settings = [optimization_settings] + + optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize in optimization_settings: + _check_optimizer_params(opt_type, opt_target) + + opt_config = _get_optimizer_configuration( + opt_type, opt_target, layers_to_optimize + ) + optimizer_configs.append(opt_config) + + if len(optimizer_configs) == 1: + return get_optimizer(model, optimizer_configs[0]) + + return MultiStageOptimizer(model, optimizer_configs) + + +def _get_optimizer_configuration( + optimization_type: str, + optimization_target: Union[int, float], + layers_to_optimize: Optional[List[str]] = None, +) -> OptimizerConfiguration: + """Get optimizer configuration for provided parameters.""" + _check_optimizer_params(optimization_type, optimization_target) + + opt_type = optimization_type.lower() + if opt_type == "pruning": + return PruningConfiguration(optimization_target, layers_to_optimize) + + if opt_type == "clustering": + # make sure an integer is given as clustering target + if optimization_target == int(optimization_target): + return ClusteringConfiguration(int(optimization_target), layers_to_optimize) + + raise ConfigurationError( + "Optimization target should be a positive integer. " + f"Optimization target provided: {optimization_target}" + ) + + raise ConfigurationError(f"Unsupported optimization type: {optimization_type}") + + +def _check_optimizer_params( + optimization_type: str, optimization_target: Union[int, float] +) -> None: + """Check optimizer params.""" + if not optimization_target: + raise ConfigurationError("Optimization target is not provided") + + if not optimization_type: + raise ConfigurationError("Optimization type is not provided") diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py new file mode 100644 index 0000000..b29fab3 --- /dev/null +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +""" +Contains class TFLiteMetrics to calculate metrics from a TFLite file. + +These metrics include: +* Sparsity (per layer and overall) +* Unique weights (clusters) (per layer) +* gzip compression ratio +""" +import os +from enum import Enum +from pprint import pprint +from typing import Any +from typing import List +from typing import Optional + +import numpy as np +import tensorflow as tf + +DEFAULT_IGNORE_LIST = [ + "relu", + "pooling", + "reshape", + "identity", + "input", + "add", + "flatten", + "StatefulPartitionedCall", + "bias", +] + + +def calculate_num_unique_weights(weights: np.array) -> int: + """Calculate the number of unique weights in the given weights.""" + num_unique_weights = len(np.unique(weights)) + return num_unique_weights + + +def calculate_num_unique_weights_per_axis(weights: np.array, axis: int) -> List[int]: + """Calculate unique weights per quantization axis.""" + # Make quantized dimension the first dimension + weights_trans = np.swapaxes(weights, 0, axis) + num_uniques_weights = [ + calculate_num_unique_weights(weights_trans[i]) + for i in range(weights_trans.shape[0]) + ] + assert num_uniques_weights + return num_uniques_weights + + +class SparsityAccumulator: + """Helper class to accumulate sparsity over several layers.""" + + def __init__(self) -> None: + """Create an empty accumulator.""" + self.total_non_zero_weights: int = 0 + self.total_weights: int = 0 + + def __call__(self, weights: np.array) -> None: + """Update the accumulator with the given weights.""" + non_zero_weights = np.count_nonzero(weights) + self.total_non_zero_weights += non_zero_weights + self.total_weights += weights.size + + def sparsity(self) -> float: + """Calculate the sparsity for all added weights.""" + return 1.0 - self.total_non_zero_weights / float(self.total_weights) + + +def calculate_sparsity( + weights: np.array, accumulator: Optional[SparsityAccumulator] = None +) -> float: + """ + Calculate the sparsity for the given weights. + + If the accumulator is passed, it is updated as well. + """ + non_zero_weights = np.count_nonzero(weights) + sparsity = 1.0 - float(non_zero_weights) / float(weights.size) + if accumulator is not None: + accumulator(weights) + return sparsity + + +class ReportClusterMode(Enum): + """Specifies the way cluster values are aggregated and reported.""" + + NUM_CLUSTERS_HISTOGRAM = ( + "A histogram of the number of clusters per axis. " + "I.e. the number of clusters is the index of the list (the bin) and " + "the value is the number of axes that have this number of clusters. " + "The first bin is 1." + ) + NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis." + NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes." + + +class TFLiteMetrics: + """Helper class to calculate metrics from a TFLite file. + + Metrics include: + * sparsity (per-layer and overall) + * number of unique weights (clusters) per layer + * File compression via gzip + """ + + def __init__( + self, tflite_file: str, ignore_list: Optional[List[str]] = None + ) -> None: + """Load the TFLite file and filter layers.""" + self.tflite_file = tflite_file + if ignore_list is None: + ignore_list = DEFAULT_IGNORE_LIST + self.ignore_list = [ignore.casefold() for ignore in ignore_list] + # Initialize the TFLite interpreter with the model file + self.interpreter = tf.lite.Interpreter(model_path=tflite_file) + self.interpreter.allocate_tensors() + self.details: dict = {} + + def ignore(details: dict) -> bool: + name = details["name"].casefold() + if not name: + return True + for to_ignore in self.ignore_list: + if to_ignore in name: + return True + return False + + self.filtered_details = { + details["name"]: details + for details in self.interpreter.get_tensor_details() + if not ignore(details) + } + + def get_tensor(self, details: dict) -> Any: + """Return the weights/tensor specified in the given details map.""" + return self.interpreter.tensor(details["index"])() + + def sparsity_per_layer(self) -> dict: + """Return a dict of layer name and sparsity value.""" + sparsity = { + name: calculate_sparsity(self.get_tensor(details)) + for name, details in self.filtered_details.items() + } + return sparsity + + def sparsity_overall(self) -> float: + """Return an instance of SparsityAccumulator for the filtered layers.""" + acc = SparsityAccumulator() + for details in self.filtered_details.values(): + acc(self.get_tensor(details)) + return acc.sparsity() + + def calc_num_clusters_per_axis(self, details: dict) -> List[int]: + """Calculate number of clusters per axis.""" + quant_params = details["quantization_parameters"] + per_axis = len(quant_params["zero_points"]) > 1 + if per_axis: + # Calculate unique weights along quantization axis + axis = quant_params["quantized_dimension"] + return calculate_num_unique_weights_per_axis(self.get_tensor(details), axis) + + # Calculate unique weights over all axes/dimensions + return [calculate_num_unique_weights(self.get_tensor(details))] + + def num_unique_weights(self, mode: ReportClusterMode) -> dict: + """Return a dict of layer name and number of unique weights.""" + aggregation_func = None + if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS: + aggregation_func = self.calc_num_clusters_per_axis + elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX: + + def cluster_min_max(details: dict) -> List[int]: + num_clusters = self.calc_num_clusters_per_axis(details) + return [min(num_clusters), max(num_clusters)] + + aggregation_func = cluster_min_max + elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM: + + def cluster_hist(details: dict) -> List[int]: + num_clusters = self.calc_num_clusters_per_axis(details) + max_num = max(num_clusters) + hist = [0] * (max_num) + for num in num_clusters: + idx = num - 1 + hist[idx] += 1 + return hist + + aggregation_func = cluster_hist + else: + raise NotImplementedError( + "ReportClusterMode '{}' not implemented.".format(mode) + ) + uniques = { + name: aggregation_func(details) + for name, details in self.filtered_details.items() + } + return uniques + + @staticmethod + def _prettify_name(name: str) -> str: + if name.startswith("model"): + return name.split("/", 1)[1] + return name + + def summary( + self, + report_sparsity: bool, + report_cluster_mode: ReportClusterMode = None, + max_num_clusters: int = 32, + verbose: bool = False, + ) -> None: + """Print a summary of all the model information.""" + print("Model file: {}".format(self.tflite_file)) + print("#" * 80) + print(" " * 28 + "### TFLITE SUMMARY ###") + print("File: {}".format(os.path.abspath(self.tflite_file))) + print("Input(s):") + self._print_in_outs(self.interpreter.get_input_details(), verbose) + print("Output(s):") + self._print_in_outs(self.interpreter.get_output_details(), verbose) + print() + header = ["Layer", "Index", "Type", "Num weights"] + if report_sparsity: + header.append("Sparsity") + rows = [] + sparsity_accumulator = SparsityAccumulator() + for details in self.filtered_details.values(): + name = details["name"] + weights = self.get_tensor(details) + row = [ + self._prettify_name(name), + details["index"], + weights.dtype, + weights.size, + ] + if report_sparsity: + sparsity = calculate_sparsity(weights, sparsity_accumulator) + row.append("{:.2f}".format(sparsity)) + rows.append(row) + if verbose: + # Print cluster centroids + print("{} cluster centroids:".format(name)) + pprint(np.unique(weights)) + # Add summary/overall values + empty_row = ["" for _ in range(len(header))] + summary_row = empty_row + summary_row[header.index("Layer")] = "=> OVERALL" + summary_row[header.index("Num weights")] = str( + sparsity_accumulator.total_weights + ) + if report_sparsity: + summary_row[header.index("Sparsity")] = "{:.2f}".format( + sparsity_accumulator.sparsity() + ) + rows.append(summary_row) + # Report detailed cluster info + if report_cluster_mode is not None: + print() + self._print_cluster_details(report_cluster_mode, max_num_clusters) + print("#" * 80) + + def _print_cluster_details( + self, report_cluster_mode: ReportClusterMode, max_num_clusters: int + ) -> None: + print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value)) + num_clusters = self.num_unique_weights(report_cluster_mode) + if ( + report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM + and max_num_clusters > 0 + ): + # Only show cluster histogram if there are not more than + # max_num_clusters. This is a workaround for not showing a huge + # histogram for unclustered layers. + for name, value in num_clusters.items(): + if len(value) > max_num_clusters: + num_clusters[name] = "More than {} unique values.".format( + max_num_clusters + ) + for name, nums in num_clusters.items(): + print("- {}: {}".format(self._prettify_name(name), nums)) + + @staticmethod + def _print_in_outs(ios: List[dict], verbose: bool = False) -> None: + for item in ios: + if verbose: + pprint(item) + else: + print( + "- {} ({}): {}".format( + item["name"], + np.dtype(item["dtype"]).name, + item["shape"], + ) + ) diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py new file mode 100644 index 0000000..4abf6cd --- /dev/null +++ b/src/mlia/nn/tensorflow/utils.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Collection of useful functions for optimizations.""" +import logging +from pathlib import Path +from typing import Callable +from typing import Iterable +from typing import Union + +import numpy as np +import tensorflow as tf +from tensorflow.lite.python.interpreter import Interpreter + +from mlia.utils.logging import redirect_output + + +def representative_dataset(model: tf.keras.Model) -> Callable: + """Sample dataset used for quantization.""" + input_shape = model.input_shape + + def dataset() -> Iterable: + for _ in range(100): + if input_shape[0] != 1: + raise Exception("Only the input batch_size=1 is supported!") + data = np.random.rand(*input_shape) + yield [data.astype(np.float32)] + + return dataset + + +def get_tf_tensor_shape(model: str) -> list: + """Get input shape for the TensorFlow tensor model.""" + # Loading the model + loaded = tf.saved_model.load(model) + # The model signature must have 'serving_default' as a key + if "serving_default" not in loaded.signatures.keys(): + raise Exception( + "Unsupported TensorFlow model signature, must have 'serving_default'" + ) + # Get the signature inputs + inputs_tensor_info = loaded.signatures["serving_default"].inputs + dims = [] + # Build a list of all inputs shape sizes + for input_key in inputs_tensor_info: + if input_key.get_shape(): + dims.extend(list(input_key.get_shape())) + return dims + + +def representative_tf_dataset(model: str) -> Callable: + """Sample dataset used for quantization.""" + if not (input_shape := get_tf_tensor_shape(model)): + raise Exception("Unable to get input shape") + + def dataset() -> Iterable: + for _ in range(100): + data = np.random.rand(*input_shape) + yield [data.astype(np.float32)] + + return dataset + + +def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter: + """Convert Keras model to TFLite.""" + if not isinstance(model, tf.keras.Model): + raise Exception("Invalid model type") + + converter = tf.lite.TFLiteConverter.from_keras_model(model) + + if quantized: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset(model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + + with redirect_output(logging.getLogger("tensorflow")): + tflite_model = converter.convert() + + return tflite_model + + +def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter: + """Convert TensorFlow model to TFLite.""" + if not isinstance(model, str): + raise Exception("Invalid model type") + + converter = tf.lite.TFLiteConverter.from_saved_model(model) + + if quantized: + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_tf_dataset(model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + + with redirect_output(logging.getLogger("tensorflow")): + tflite_model = converter.convert() + + return tflite_model + + +def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None: + """Save Keras model at provided path.""" + # Checkpoint: saving the optimizer is necessary. + model.save(save_path, include_optimizer=True) + + +def save_tflite_model( + model: tf.lite.TFLiteConverter, save_path: Union[str, Path] +) -> None: + """Save TFLite model at provided path.""" + with open(save_path, "wb") as file: + file.write(model) + + +def is_tflite_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by TFLite API. + + TFLite model is indicated by the model file extension .tflite + """ + model_path = Path(model) + return model_path.suffix == ".tflite" + + +def is_keras_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by Keras API. + + Keras model is indicated by: + 1. if it's a directory (meaning saved model), + it should contain keras_metadata.pb file + 2. or if the model file extension is .h5/.hdf5 + """ + model_path = Path(model) + + if model_path.is_dir(): + return (model_path / "keras_metadata.pb").exists() + return model_path.suffix in (".h5", ".hdf5") + + +def is_tf_model(model: Union[Path, str]) -> bool: + """Check if model type is supported by TensorFlow API. + + TensorFlow model is indicated if its directory (meaning saved model) + doesn't contain keras_metadata.pb file + """ + model_path = Path(model) + return model_path.is_dir() and not is_keras_model(model) diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/aiet/applications/APPLICATIONS.txt new file mode 100644 index 0000000..09127f8 --- /dev/null +++ b/src/mlia/resources/aiet/applications/APPLICATIONS.txt @@ -0,0 +1,6 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +This directory contains the Generic Inference Runner application packages for AIET + +Each package should contain its own aiet-config.json file diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json new file mode 100644 index 0000000..757ccd1 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json @@ -0,0 +1,18 @@ +[ + { + "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.", + "supported_systems": [ + { + "name": "Corstone-300: Cortex-M55+Ethos-U55" + }, + { + "name": "Corstone-300: Cortex-M55+Ethos-U65" + } + ], + "lock": true, + "variables": { + "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" + } + } +] diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf new file mode 100644 index 0000000..4c50e1f Binary files /dev/null and b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf differ diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license new file mode 100644 index 0000000..8896f92 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license @@ -0,0 +1,31 @@ +SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates. +SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano +SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation +SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc +SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman +SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc. +SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about) +SPDX-FileCopyrightText: Copyright 2011, John Resig +SPDX-FileCopyrightText: Copyright 2016-2021 NXP +SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org +SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved. +SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch. +SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems +SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure +SPDX-FileCopyrightText: Copyright 2015 gRPC authors. +SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved. +SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch +SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors +SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors +SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc. +SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello. +SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar +SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved. +SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved. +SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors +SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json new file mode 100644 index 0000000..cb7e113 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json @@ -0,0 +1,15 @@ +[ + { + "name": "Generic Inference Runner: Ethos-U55 SRAM", + "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.", + "supported_systems": [ + { + "name": "Corstone-300: Cortex-M55+Ethos-U55" + } + ], + "lock": true, + "variables": { + "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" + } + } +] diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf new file mode 100644 index 0000000..850e2eb Binary files /dev/null and b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf differ diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license new file mode 100644 index 0000000..8896f92 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license @@ -0,0 +1,31 @@ +SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates. +SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano +SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation +SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc +SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman +SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc. +SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about) +SPDX-FileCopyrightText: Copyright 2011, John Resig +SPDX-FileCopyrightText: Copyright 2016-2021 NXP +SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org +SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved. +SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch. +SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems +SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure +SPDX-FileCopyrightText: Copyright 2015 gRPC authors. +SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved. +SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch +SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors +SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors +SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc. +SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello. +SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar +SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved. +SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved. +SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors +SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json new file mode 100644 index 0000000..d524f64 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json @@ -0,0 +1,15 @@ +[ + { + "name": "Generic Inference Runner: Ethos-U65 Dedicated SRAM", + "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.", + "supported_systems": [ + { + "name": "Corstone-300: Cortex-M55+Ethos-U65" + } + ], + "lock": true, + "variables": { + "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" + } + } +] diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf new file mode 100644 index 0000000..f881bb8 Binary files /dev/null and b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf differ diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license new file mode 100644 index 0000000..8896f92 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license @@ -0,0 +1,31 @@ +SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates. +SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano +SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation +SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc +SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman +SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc. +SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about) +SPDX-FileCopyrightText: Copyright 2011, John Resig +SPDX-FileCopyrightText: Copyright 2016-2021 NXP +SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org +SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved. +SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch. +SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems +SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure +SPDX-FileCopyrightText: Copyright 2015 gRPC authors. +SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved. +SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch +SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors +SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors +SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc. +SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello. +SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar +SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved. +SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved. +SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors +SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json new file mode 100644 index 0000000..2cbab70 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json @@ -0,0 +1,15 @@ +[ + { + "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.", + "supported_systems": [ + { + "name": "Corstone-310: Cortex-M85+Ethos-U55" + } + ], + "lock": true, + "variables": { + "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" + } + } +] diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf new file mode 100644 index 0000000..846ee33 Binary files /dev/null and b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf differ diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license new file mode 100644 index 0000000..8896f92 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license @@ -0,0 +1,31 @@ +SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates. +SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano +SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation +SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc +SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman +SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc. +SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about) +SPDX-FileCopyrightText: Copyright 2011, John Resig +SPDX-FileCopyrightText: Copyright 2016-2021 NXP +SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org +SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved. +SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch. +SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems +SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure +SPDX-FileCopyrightText: Copyright 2015 gRPC authors. +SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved. +SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch +SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors +SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors +SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc. +SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello. +SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar +SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved. +SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved. +SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors +SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json new file mode 100644 index 0000000..01bec74 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json @@ -0,0 +1,15 @@ +[ + { + "name": "Generic Inference Runner: Ethos-U55 SRAM", + "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.", + "supported_systems": [ + { + "name": "Corstone-310: Cortex-M85+Ethos-U55" + } + ], + "lock": true, + "variables": { + "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" + } + } +] diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf new file mode 100644 index 0000000..e3eab97 Binary files /dev/null and b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf differ diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license new file mode 100644 index 0000000..8896f92 --- /dev/null +++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license @@ -0,0 +1,31 @@ +SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates. +SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano +SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation +SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc +SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman +SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc. +SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about) +SPDX-FileCopyrightText: Copyright 2011, John Resig +SPDX-FileCopyrightText: Copyright 2016-2021 NXP +SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org +SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved. +SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch. +SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems +SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure +SPDX-FileCopyrightText: Copyright 2015 gRPC authors. +SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved. +SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch +SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors +SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors +SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc. +SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello. +SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar +SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved. +SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved. +SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved. +SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors +SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/aiet/systems/SYSTEMS.txt new file mode 100644 index 0000000..bc27e73 --- /dev/null +++ b/src/mlia/resources/aiet/systems/SYSTEMS.txt @@ -0,0 +1,10 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +This directory contains the configuration files of the systems for the AIET +middleware. + +Supported systems: + + * FVP Corstone-300 Ecosystem + * FVP Corstone-310 Ecosystem diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json new file mode 100644 index 0000000..3ffa548 --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json @@ -0,0 +1,80 @@ +[ + { + "name": "Corstone-300: Cortex-M55+Ethos-U55", + "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U55", + "sim_type": "FM", + "variant": "Cortex-M55+Ethos-U55" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.", + "values": [ + "32", + "64", + "128", + "256" + ], + "default_value": "256", + "alias": "mac" + } + ] + } + }, + { + "name": "Corstone-300: Cortex-M55+Ethos-U65", + "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U65", + "sim_type": "FM", + "variant": "Cortex-M55+Ethos-U65" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.", + "values": [ + "256", + "512" + ], + "default_value": "512", + "alias": "mac" + } + ] + } + } +] diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json new file mode 100644 index 0000000..6d6785d --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json @@ -0,0 +1,80 @@ +[ + { + "name": "Corstone-300: Cortex-M55+Ethos-U55", + "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U55", + "sim_type": "FM", + "variant": "Cortex-M55+Ethos-U55" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "FVP_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.", + "values": [ + "32", + "64", + "128", + "256" + ], + "default_value": "256", + "alias": "mac" + } + ] + } + }, + { + "name": "Corstone-300: Cortex-M55+Ethos-U65", + "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U65", + "sim_type": "FM", + "variant": "Cortex-M55+Ethos-U65" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "FVP_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.", + "values": [ + "256", + "512" + ], + "default_value": "512", + "alias": "mac" + } + ] + } + } +] diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json new file mode 100644 index 0000000..dbc2622 --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json @@ -0,0 +1,42 @@ +[ + { + "name": "Corstone-310: Cortex-M85+Ethos-U55", + "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U55", + "sim_type": "FM", + "variant": "Cortex-M85+Ethos-U55" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "/opt/VHT/VHT_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.", + "values": [ + "32", + "64", + "128", + "256" + ], + "default_value": "256", + "alias": "mac" + } + ] + } + } +] diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json new file mode 100644 index 0000000..7aa3b0a --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json @@ -0,0 +1,42 @@ +[ + { + "name": "Corstone-310: Cortex-M85+Ethos-U55", + "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.", + "annotations": { + "ip_class": "Ethos-U55", + "sim_type": "FM", + "variant": "Cortex-M85+Ethos-U55" + }, + "data_transfer": { + "protocol": "local" + }, + "lock": true, + "commands": { + "run": [ + "FVP_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" + ] + }, + "user_params": { + "run": [ + { + "name": "--data", + "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.", + "values": [], + "alias": "input_file" + }, + { + "name": "ethosu.num_macs=", + "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.", + "values": [ + "32", + "64", + "128", + "256" + ], + "default_value": "256", + "alias": "mac" + } + ] + } + } +] diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json new file mode 100644 index 0000000..4493d7b --- /dev/null +++ b/src/mlia/resources/profiles.json @@ -0,0 +1,20 @@ +{ + "ethos-u55-256": { + "target": "ethos-u55", + "mac": 256, + "system_config": "Ethos_U55_High_End_Embedded", + "memory_mode": "Shared_Sram" + }, + "ethos-u55-128": { + "target": "ethos-u55", + "mac": 128, + "system_config": "Ethos_U55_High_End_Embedded", + "memory_mode": "Shared_Sram" + }, + "ethos-u65-512": { + "target": "ethos-u65", + "mac": 512, + "system_config": "Ethos_U65_High_End", + "memory_mode": "Dedicated_Sram" + } +} diff --git a/src/mlia/resources/profiles.json.license b/src/mlia/resources/profiles.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/src/mlia/resources/profiles.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/src/mlia/resources/vela/vela.ini b/src/mlia/resources/vela/vela.ini new file mode 100644 index 0000000..382820d --- /dev/null +++ b/src/mlia/resources/vela/vela.ini @@ -0,0 +1,75 @@ +; SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. +; SPDX-License-Identifier: Apache-2.0 + +; ----------------------------------------------------------------------------- +; Vela configuration file +; ----------------------------------------------------------------------------- + +; System Configuration + +; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s) +[System_Config.Ethos_U55_High_End_Embedded] +core_clock=500e6 +axi0_port=Sram +axi1_port=OffChipFlash +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +OffChipFlash_clock_scale=0.125 +OffChipFlash_burst_length=128 +OffChipFlash_read_latency=64 +OffChipFlash_write_latency=64 + +; Ethos-U65 Embedded: SRAM (8 GB/s) and Flash (0.5 GB/s) +[System_Config.Ethos_U65_Embedded] +core_clock=500e6 +axi0_port=Sram +axi1_port=OffChipFlash +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +OffChipFlash_clock_scale=0.0625 +OffChipFlash_burst_length=128 +OffChipFlash_read_latency=64 +OffChipFlash_write_latency=64 + +; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s) +[System_Config.Ethos_U65_High_End] +core_clock=1e9 +axi0_port=Sram +axi1_port=Dram +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +Dram_clock_scale=0.234375 +Dram_burst_length=128 +Dram_read_latency=500 +Dram_write_latency=250 + +; ----------------------------------------------------------------------------- + +; Memory Mode + +; SRAM Only: only one AXI port is used and the SRAM is used for all storage +[Memory_Mode.Sram_Only] +const_mem_area=Axi0 +arena_mem_area=Axi0 +cache_mem_area=Axi0 + +; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software +; The non-SRAM memory is assumed to be read-only +[Memory_Mode.Shared_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi0 +cache_mem_area=Axi0 + +; Dedicated SRAM: the SRAM (384KB) is only for use by the Ethos-U +; The non-SRAM memory is assumed to be read-writeable +[Memory_Mode.Dedicated_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi1 +cache_mem_area=Axi0 +arena_cache_size=393216 diff --git a/src/mlia/tools/__init__.py b/src/mlia/tools/__init__.py new file mode 100644 index 0000000..184e966 --- /dev/null +++ b/src/mlia/tools/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tools module.""" diff --git a/src/mlia/tools/aiet_wrapper.py b/src/mlia/tools/aiet_wrapper.py new file mode 100644 index 0000000..73e82ee --- /dev/null +++ b/src/mlia/tools/aiet_wrapper.py @@ -0,0 +1,435 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for AIET integration.""" +import logging +import re +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple + +from aiet.backend.application import get_available_applications +from aiet.backend.application import install_application +from aiet.backend.system import get_available_systems +from aiet.backend.system import install_system +from mlia.utils.proc import CommandExecutor +from mlia.utils.proc import OutputConsumer +from mlia.utils.proc import RunningCommand + + +logger = logging.getLogger(__name__) + +# Mapping backend -> device_type -> system_name +_SUPPORTED_SYSTEMS = { + "Corstone-300": { + "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55", + "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65", + }, + "Corstone-310": { + "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55", + }, +} + +# Mapping system_name -> memory_mode -> application +_SYSTEM_TO_APP_MAP = { + "Corstone-300: Cortex-M55+Ethos-U55": { + "Sram": "Generic Inference Runner: Ethos-U55 SRAM", + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + }, + "Corstone-300: Cortex-M55+Ethos-U65": { + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM", + }, + "Corstone-310: Cortex-M85+Ethos-U55": { + "Sram": "Generic Inference Runner: Ethos-U55 SRAM", + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + }, +} + + +def get_system_name(backend: str, device_type: str) -> str: + """Get the AIET system name for the given backend and device type.""" + return _SUPPORTED_SYSTEMS[backend][device_type] + + +def is_supported(backend: str, device_type: Optional[str] = None) -> bool: + """Check if the backend (and optionally device type) is supported.""" + if device_type is None: + return backend in _SUPPORTED_SYSTEMS + + try: + get_system_name(backend, device_type) + return True + except KeyError: + return False + + +def supported_backends() -> List[str]: + """Get a list of all backends supported by the AIET wrapper.""" + return list(_SUPPORTED_SYSTEMS.keys()) + + +def get_all_system_names(backend: str) -> List[str]: + """Get all systems supported by the backend.""" + return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) + + +def get_all_application_names(backend: str) -> List[str]: + """Get all applications supported by the backend.""" + app_set = { + app + for sys in get_all_system_names(backend) + for app in _SYSTEM_TO_APP_MAP[sys].values() + } + return list(app_set) + + +@dataclass +class DeviceInfo: + """Device information.""" + + device_type: Literal["ethos-u55", "ethos-u65"] + mac: int + memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"] + + +@dataclass +class ModelInfo: + """Model info.""" + + model_path: Path + + +@dataclass +class PerformanceMetrics: + """Performance metrics parsed from generic inference output.""" + + npu_active_cycles: int + npu_idle_cycles: int + npu_total_cycles: int + npu_axi0_rd_data_beat_received: int + npu_axi0_wr_data_beat_written: int + npu_axi1_rd_data_beat_received: int + + +@dataclass +class ExecutionParams: + """Application execution params.""" + + application: str + system: str + application_params: List[str] + system_params: List[str] + deploy_params: List[str] + + +class AIETLogWriter(OutputConsumer): + """Redirect AIET command output to the logger.""" + + def feed(self, line: str) -> None: + """Process line from the output.""" + logger.debug(line.strip()) + + +class GenericInferenceOutputParser(OutputConsumer): + """Generic inference app output parser.""" + + PATTERNS = { + name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns) + for name, patterns in ( + ( + "npu_active_cycles", + ( + r"NPU ACTIVE cycles: (?P\d+)", + r"NPU ACTIVE: (?P\d+) cycles", + ), + ), + ( + "npu_idle_cycles", + ( + r"NPU IDLE cycles: (?P\d+)", + r"NPU IDLE: (?P\d+) cycles", + ), + ), + ( + "npu_total_cycles", + ( + r"NPU TOTAL cycles: (?P\d+)", + r"NPU TOTAL: (?P\d+) cycles", + ), + ), + ( + "npu_axi0_rd_data_beat_received", + ( + r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", + r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", + ), + ), + ( + "npu_axi0_wr_data_beat_written", + ( + r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P\d+)", + r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P\d+) beats", + ), + ), + ( + "npu_axi1_rd_data_beat_received", + ( + r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", + r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", + ), + ), + ) + } + + def __init__(self) -> None: + """Init generic inference output parser instance.""" + self.result: Dict = {} + + def feed(self, line: str) -> None: + """Feed new line to the parser.""" + for name, patterns in self.PATTERNS.items(): + for pattern in patterns: + match = pattern.search(line) + + if match: + self.result[name] = int(match["value"]) + return + + def is_ready(self) -> bool: + """Return true if all expected data has been parsed.""" + return self.result.keys() == self.PATTERNS.keys() + + def missed_keys(self) -> List[str]: + """Return list of the keys that have not been found in the output.""" + return sorted(self.PATTERNS.keys() - self.result.keys()) + + +class AIETRunner: + """AIET runner.""" + + def __init__(self, executor: CommandExecutor) -> None: + """Init AIET runner instance.""" + self.executor = executor + + @staticmethod + def get_installed_systems() -> List[str]: + """Get list of the installed systems.""" + return [system.name for system in get_available_systems()] + + @staticmethod + def get_installed_applications(system: Optional[str] = None) -> List[str]: + """Get list of the installed application.""" + return [ + app.name + for app in get_available_applications() + if system is None or app.can_run_on(system) + ] + + def is_application_installed(self, application: str, system: str) -> bool: + """Return true if requested application installed.""" + return application in self.get_installed_applications(system) + + def is_system_installed(self, system: str) -> bool: + """Return true if requested system installed.""" + return system in self.get_installed_systems() + + def systems_installed(self, systems: List[str]) -> bool: + """Check if all provided systems are installed.""" + if not systems: + return False + + installed_systems = self.get_installed_systems() + return all(system in installed_systems for system in systems) + + def applications_installed(self, applications: List[str]) -> bool: + """Check if all provided applications are installed.""" + if not applications: + return False + + installed_apps = self.get_installed_applications() + return all(app in installed_apps for app in applications) + + def all_installed(self, systems: List[str], apps: List[str]) -> bool: + """Check if all provided artifacts are installed.""" + return self.systems_installed(systems) and self.applications_installed(apps) + + @staticmethod + def install_system(system_path: Path) -> None: + """Install system.""" + install_system(system_path) + + @staticmethod + def install_application(app_path: Path) -> None: + """Install application.""" + install_application(app_path) + + def run_application(self, execution_params: ExecutionParams) -> RunningCommand: + """Run requested application.""" + command = [ + "aiet", + "application", + "run", + "-n", + execution_params.application, + "-s", + execution_params.system, + *self._params("-p", execution_params.application_params), + *self._params("--system-param", execution_params.system_params), + *self._params("--deploy", execution_params.deploy_params), + ] + + return self._submit(command) + + @staticmethod + def _params(name: str, params: List[str]) -> List[str]: + return [p for item in [(name, param) for param in params] for p in item] + + def _submit(self, command: List[str]) -> RunningCommand: + """Submit command for the execution.""" + logger.debug("Submit command %s", " ".join(command)) + return self.executor.submit(command) + + +class GenericInferenceRunner(ABC): + """Abstract class for generic inference runner.""" + + def __init__(self, aiet_runner: AIETRunner): + """Init generic inference runner instance.""" + self.aiet_runner = aiet_runner + self.running_inference: Optional[RunningCommand] = None + + def run( + self, model_info: ModelInfo, output_consumers: List[OutputConsumer] + ) -> None: + """Run generic inference for the provided device/model.""" + execution_params = self.get_execution_params(model_info) + + self.running_inference = self.aiet_runner.run_application(execution_params) + self.running_inference.output_consumers = output_consumers + self.running_inference.consume_output() + + def stop(self) -> None: + """Stop running inference.""" + if self.running_inference is None: + return + + self.running_inference.stop() + + @abstractmethod + def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: + """Get execution params for the provided model.""" + + def __enter__(self) -> "GenericInferenceRunner": + """Enter context.""" + return self + + def __exit__(self, *_args: Any) -> None: + """Exit context.""" + self.stop() + + def check_system_and_application(self, system_name: str, app_name: str) -> None: + """Check if requested system and application installed.""" + if not self.aiet_runner.is_system_installed(system_name): + raise Exception(f"System {system_name} is not installed") + + if not self.aiet_runner.is_application_installed(app_name, system_name): + raise Exception( + f"Application {app_name} for the system {system_name} " + "is not installed" + ) + + +class GenericInferenceRunnerEthosU(GenericInferenceRunner): + """Generic inference runner on U55/65.""" + + def __init__( + self, aiet_runner: AIETRunner, device_info: DeviceInfo, backend: str + ) -> None: + """Init generic inference runner instance.""" + super().__init__(aiet_runner) + + system_name, app_name = self.resolve_system_and_app(device_info, backend) + self.system_name = system_name + self.app_name = app_name + self.device_info = device_info + + @staticmethod + def resolve_system_and_app( + device_info: DeviceInfo, backend: str + ) -> Tuple[str, str]: + """Find appropriate system and application for the provided device/backend.""" + try: + system_name = get_system_name(backend, device_info.device_type) + except KeyError as ex: + raise RuntimeError( + f"Unsupported device {device_info.device_type} " + f"for backend {backend}" + ) from ex + + if system_name not in _SYSTEM_TO_APP_MAP: + raise RuntimeError(f"System {system_name} is not installed") + + try: + app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode] + except KeyError as err: + raise RuntimeError( + f"Unsupported memory mode {device_info.memory_mode}" + ) from err + + return system_name, app_name + + def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: + """Get execution params for Ethos-U55/65.""" + self.check_system_and_application(self.system_name, self.app_name) + + system_params = [ + f"mac={self.device_info.mac}", + f"input_file={model_info.model_path.absolute()}", + ] + + return ExecutionParams( + self.app_name, + self.system_name, + [], + system_params, + [], + ) + + +def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner: + """Get generic runner for provided device and backend.""" + aiet_runner = get_aiet_runner() + return GenericInferenceRunnerEthosU(aiet_runner, device_info, backend) + + +def estimate_performance( + model_info: ModelInfo, device_info: DeviceInfo, backend: str +) -> PerformanceMetrics: + """Get performance estimations.""" + with get_generic_runner(device_info, backend) as generic_runner: + output_parser = GenericInferenceOutputParser() + output_consumers = [output_parser, AIETLogWriter()] + + generic_runner.run(model_info, output_consumers) + + if not output_parser.is_ready(): + missed_data = ",".join(output_parser.missed_keys()) + logger.debug( + "Unable to get performance metrics, missed data %s", missed_data + ) + raise Exception("Unable to get performance metrics, insufficient data") + + return PerformanceMetrics(**output_parser.result) + + +def get_aiet_runner() -> AIETRunner: + """Return AIET runner.""" + executor = CommandExecutor() + return AIETRunner(executor) diff --git a/src/mlia/tools/metadata/__init__.py b/src/mlia/tools/metadata/__init__.py new file mode 100644 index 0000000..f877e4f --- /dev/null +++ b/src/mlia/tools/metadata/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for the tools metadata.""" diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py new file mode 100644 index 0000000..c17a738 --- /dev/null +++ b/src/mlia/tools/metadata/common.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for installation process.""" +import logging +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Callable +from typing import List +from typing import Optional +from typing import Union + +from mlia.utils.misc import yes + + +logger = logging.getLogger(__name__) + + +@dataclass +class InstallFromPath: + """Installation from the local path.""" + + backend_path: Path + + +@dataclass +class DownloadAndInstall: + """Download and install.""" + + eula_agreement: bool = True + + +InstallationType = Union[InstallFromPath, DownloadAndInstall] + + +class Installation(ABC): + """Base class for the installation process of the backends.""" + + @property + @abstractmethod + def name(self) -> str: + """Return name of the backend.""" + + @property + @abstractmethod + def description(self) -> str: + """Return description of the backend.""" + + @property + @abstractmethod + def could_be_installed(self) -> bool: + """Return true if backend could be installed in current environment.""" + + @property + @abstractmethod + def already_installed(self) -> bool: + """Return true if backend is already installed.""" + + @abstractmethod + def supports(self, install_type: InstallationType) -> bool: + """Return true if installation supports requested installation type.""" + + @abstractmethod + def install(self, install_type: InstallationType) -> None: + """Install the backend.""" + + +InstallationFilter = Callable[[Installation], bool] + + +class AlreadyInstalledFilter: + """Filter for already installed backends.""" + + def __call__(self, installation: Installation) -> bool: + """Installation filter.""" + return installation.already_installed + + +class ReadyForInstallationFilter: + """Filter for ready to be installed backends.""" + + def __call__(self, installation: Installation) -> bool: + """Installation filter.""" + return installation.could_be_installed and not installation.already_installed + + +class SupportsInstallTypeFilter: + """Filter backends that support certain type of the installation.""" + + def __init__(self, installation_type: InstallationType) -> None: + """Init filter.""" + self.installation_type = installation_type + + def __call__(self, installation: Installation) -> bool: + """Installation filter.""" + return installation.supports(self.installation_type) + + +class SearchByNameFilter: + """Filter installation by name.""" + + def __init__(self, backend_name: Optional[str]) -> None: + """Init filter.""" + self.backend_name = backend_name + + def __call__(self, installation: Installation) -> bool: + """Installation filter.""" + return not self.backend_name or installation.name == self.backend_name + + +class InstallationManager(ABC): + """Helper class for managing installations.""" + + @abstractmethod + def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None: + """Install backend from the local directory.""" + + @abstractmethod + def download_and_install( + self, backend_name: Optional[str], eula_agreement: bool + ) -> None: + """Download and install backends.""" + + @abstractmethod + def show_env_details(self) -> None: + """Show environment details.""" + + @abstractmethod + def backend_installed(self, backend_name: str) -> bool: + """Return true if requested backend installed.""" + + +class InstallationFiltersMixin: + """Mixin for filtering installation based on different conditions.""" + + installations: List[Installation] + + def filter_by(self, *filters: InstallationFilter) -> List[Installation]: + """Filter installations.""" + return [ + installation + for installation in self.installations + if all(filter_(installation) for filter_ in filters) + ] + + def could_be_installed_from( + self, backend_path: Path, backend_name: Optional[str] + ) -> List[Installation]: + """Return installations that could be installed from provided directory.""" + return self.filter_by( + SupportsInstallTypeFilter(InstallFromPath(backend_path)), + SearchByNameFilter(backend_name), + ) + + def could_be_downloaded_and_installed( + self, backend_name: Optional[str] = None + ) -> List[Installation]: + """Return installations that could be downloaded and installed.""" + return self.filter_by( + SupportsInstallTypeFilter(DownloadAndInstall()), + SearchByNameFilter(backend_name), + ReadyForInstallationFilter(), + ) + + def already_installed( + self, backend_name: Optional[str] = None + ) -> List[Installation]: + """Return list of backends that are already installed.""" + return self.filter_by( + AlreadyInstalledFilter(), SearchByNameFilter(backend_name) + ) + + def ready_for_installation(self) -> List[Installation]: + """Return list of the backends that could be installed.""" + return self.filter_by(ReadyForInstallationFilter()) + + +class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin): + """Interactive installation manager.""" + + def __init__( + self, installations: List[Installation], noninteractive: bool = False + ) -> None: + """Init the manager.""" + self.installations = installations + self.noninteractive = noninteractive + + def choose_installation_for_path( + self, backend_path: Path, backend_name: Optional[str] + ) -> Optional[Installation]: + """Check available installation and select one if possible.""" + installs = self.could_be_installed_from(backend_path, backend_name) + + if not installs: + logger.info( + "Unfortunatelly, it was not possible to automatically " + "detect type of the installed FVP. " + "Please, check provided path to the installed FVP." + ) + return None + + if len(installs) != 1: + names = ",".join((install.name for install in installs)) + logger.info( + "Unable to correctly detect type of the installed FVP." + "The following FVPs are detected %s. Installation skipped.", + names, + ) + return None + + installation = installs[0] + if installation.already_installed: + logger.info( + "%s was found in %s, but it has been already installed.", + installation.name, + backend_path, + ) + return None + + return installation + + def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None: + """Install from the provided directory.""" + installation = self.choose_installation_for_path(backend_path, backend_name) + + if not installation: + return + + prompt = ( + f"{installation.name} was found in {backend_path}. " + "Would you like to install it?" + ) + self._install(installation, InstallFromPath(backend_path), prompt) + + def download_and_install( + self, backend_name: Optional[str] = None, eula_agreement: bool = True + ) -> None: + """Download and install available backends.""" + installations = self.could_be_downloaded_and_installed(backend_name) + + if not installations: + logger.info("No backends available for the installation.") + return + + names = ",".join((installation.name for installation in installations)) + logger.info("Following backends are available for downloading: %s", names) + + for installation in installations: + prompt = f"Would you like to download and install {installation.name}?" + self._install( + installation, DownloadAndInstall(eula_agreement=eula_agreement), prompt + ) + + def show_env_details(self) -> None: + """Print current state of the execution environment.""" + if installed := self.already_installed(): + logger.info("Installed backends:\n") + + for installation in installed: + logger.info(" - %s", installation.name) + + if could_be_installed := self.ready_for_installation(): + logger.info("Following backends could be installed:") + + for installation in could_be_installed: + logger.info(" - %s", installation.name) + + if not installed and not could_be_installed: + logger.info("No backends installed") + + def _install( + self, + installation: Installation, + installation_type: InstallationType, + prompt: str, + ) -> None: + proceed = self.noninteractive or yes(prompt) + + if proceed: + installation.install(installation_type) + logger.info("%s successfully installed.", installation.name) + else: + logger.info("%s installation canceled.", installation.name) + + def backend_installed(self, backend_name: str) -> bool: + """Return true if requested backend installed.""" + installations = self.already_installed(backend_name) + + return len(installations) == 1 diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py new file mode 100644 index 0000000..7a9d113 --- /dev/null +++ b/src/mlia/tools/metadata/corstone.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for Corstone based FVPs.""" +import logging +import platform +import subprocess +import tarfile +from dataclasses import dataclass +from pathlib import Path +from typing import Callable +from typing import Iterable +from typing import List +from typing import Optional + +import mlia.tools.aiet_wrapper as aiet +from mlia.tools.metadata.common import DownloadAndInstall +from mlia.tools.metadata.common import Installation +from mlia.tools.metadata.common import InstallationType +from mlia.tools.metadata.common import InstallFromPath +from mlia.utils.download import DownloadArtifact +from mlia.utils.filesystem import all_files_exist +from mlia.utils.filesystem import all_paths_valid +from mlia.utils.filesystem import copy_all +from mlia.utils.filesystem import get_mlia_resources +from mlia.utils.filesystem import temp_directory +from mlia.utils.proc import working_directory + +logger = logging.getLogger(__name__) + + +@dataclass +class BackendInfo: + """Backend information.""" + + backend_path: Path + copy_source: bool = True + system_config: Optional[str] = None + + +PathChecker = Callable[[Path], Optional[BackendInfo]] +BackendInstaller = Callable[[bool, Path], Path] + + +class AIETMetadata: + """AIET installation metadata.""" + + def __init__( + self, + name: str, + description: str, + system_config: str, + apps_resources: List[str], + fvp_dir_name: str, + download_artifact: Optional[DownloadArtifact], + supported_platforms: Optional[List[str]] = None, + ) -> None: + """ + Initialize AIETMetaData. + + Members expected_systems and expected_apps are filled automatically. + """ + self.name = name + self.description = description + self.system_config = system_config + self.apps_resources = apps_resources + self.fvp_dir_name = fvp_dir_name + self.download_artifact = download_artifact + self.supported_platforms = supported_platforms + + self.expected_systems = aiet.get_all_system_names(name) + self.expected_apps = aiet.get_all_application_names(name) + + @property + def expected_resources(self) -> Iterable[Path]: + """Return list of expected resources.""" + resources = [self.system_config, *self.apps_resources] + + return (get_mlia_resources() / resource for resource in resources) + + @property + def supported_platform(self) -> bool: + """Return true if current platform supported.""" + if not self.supported_platforms: + return True + + return platform.system() in self.supported_platforms + + +class AIETBasedInstallation(Installation): + """Backend installation based on AIET functionality.""" + + def __init__( + self, + aiet_runner: aiet.AIETRunner, + metadata: AIETMetadata, + path_checker: PathChecker, + backend_installer: Optional[BackendInstaller], + ) -> None: + """Init the tool installation.""" + self.aiet_runner = aiet_runner + self.metadata = metadata + self.path_checker = path_checker + self.backend_installer = backend_installer + + @property + def name(self) -> str: + """Return name of the tool.""" + return self.metadata.name + + @property + def description(self) -> str: + """Return description of the tool.""" + return self.metadata.description + + @property + def already_installed(self) -> bool: + """Return true if tool already installed.""" + return self.aiet_runner.all_installed( + self.metadata.expected_systems, self.metadata.expected_apps + ) + + @property + def could_be_installed(self) -> bool: + """Return true if tool could be installed.""" + if not self.metadata.supported_platform: + return False + + return all_paths_valid(self.metadata.expected_resources) + + def supports(self, install_type: InstallationType) -> bool: + """Return true if tools supported type of the installation.""" + if isinstance(install_type, DownloadAndInstall): + return self.metadata.download_artifact is not None + + if isinstance(install_type, InstallFromPath): + return self.path_checker(install_type.backend_path) is not None + + return False # type: ignore + + def install(self, install_type: InstallationType) -> None: + """Install the tool.""" + if isinstance(install_type, DownloadAndInstall): + download_artifact = self.metadata.download_artifact + assert download_artifact is not None, "No artifact provided" + + self.download_and_install(download_artifact, install_type.eula_agreement) + elif isinstance(install_type, InstallFromPath): + backend_path = self.path_checker(install_type.backend_path) + assert backend_path is not None, "Unable to resolve backend path" + + self.install_from(backend_path) + else: + raise Exception(f"Unable to install {install_type}") + + def install_from(self, backend_info: BackendInfo) -> None: + """Install tool from the directory.""" + mlia_resources = get_mlia_resources() + + with temp_directory() as tmpdir: + fvp_dist_dir = tmpdir / self.metadata.fvp_dir_name + + system_config = self.metadata.system_config + if backend_info.system_config: + system_config = backend_info.system_config + + resources_to_copy = [mlia_resources / system_config] + if backend_info.copy_source: + resources_to_copy.append(backend_info.backend_path) + + copy_all(*resources_to_copy, dest=fvp_dist_dir) + + self.aiet_runner.install_system(fvp_dist_dir) + + for app in self.metadata.apps_resources: + self.aiet_runner.install_application(mlia_resources / app) + + def download_and_install( + self, download_artifact: DownloadArtifact, eula_agrement: bool + ) -> None: + """Download and install the tool.""" + with temp_directory() as tmpdir: + try: + downloaded_to = download_artifact.download_to(tmpdir) + except Exception as err: + raise Exception("Unable to download backend artifact") from err + + with working_directory(tmpdir / "dist", create_dir=True) as dist_dir: + with tarfile.open(downloaded_to) as archive: + archive.extractall(dist_dir) + + assert self.backend_installer, ( + f"Backend '{self.metadata.name}' does not support " + "download and installation." + ) + backend_path = self.backend_installer(eula_agrement, dist_dir) + if self.path_checker(backend_path) is None: + raise Exception("Downloaded artifact has invalid structure") + + self.install(InstallFromPath(backend_path)) + + +class PackagePathChecker: + """Package path checker.""" + + def __init__( + self, expected_files: List[str], backend_subfolder: Optional[str] = None + ) -> None: + """Init the path checker.""" + self.expected_files = expected_files + self.backend_subfolder = backend_subfolder + + def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + """Check if directory contains all expected files.""" + resolved_paths = (backend_path / file for file in self.expected_files) + if not all_files_exist(resolved_paths): + return None + + if self.backend_subfolder: + subfolder = backend_path / self.backend_subfolder + + if not subfolder.is_dir(): + return None + + return BackendInfo(subfolder) + + return BackendInfo(backend_path) + + +class StaticPathChecker: + """Static path checker.""" + + def __init__( + self, + static_backend_path: Path, + expected_files: List[str], + copy_source: bool = False, + system_config: Optional[str] = None, + ) -> None: + """Init static path checker.""" + self.static_backend_path = static_backend_path + self.expected_files = expected_files + self.copy_source = copy_source + self.system_config = system_config + + def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + """Check if directory equals static backend path with all expected files.""" + if backend_path != self.static_backend_path: + return None + + resolved_paths = (backend_path / file for file in self.expected_files) + if not all_files_exist(resolved_paths): + return None + + return BackendInfo( + backend_path, + copy_source=self.copy_source, + system_config=self.system_config, + ) + + +class CompoundPathChecker: + """Compound path checker.""" + + def __init__(self, *path_checkers: PathChecker) -> None: + """Init compound path checker.""" + self.path_checkers = path_checkers + + def __call__(self, backend_path: Path) -> Optional[BackendInfo]: + """Iterate over checkers and return first non empty backend info.""" + first_resolved_backend_info = ( + backend_info + for path_checker in self.path_checkers + if (backend_info := path_checker(backend_path)) is not None + ) + + return next(first_resolved_backend_info, None) + + +class Corstone300Installer: + """Helper class that wraps Corstone 300 installation logic.""" + + def __call__(self, eula_agreement: bool, dist_dir: Path) -> Path: + """Install Corstone-300 and return path to the models.""" + with working_directory(dist_dir): + install_dir = "corstone-300" + try: + fvp_install_cmd = [ + "./FVP_Corstone_SSE-300.sh", + "-q", + "-d", + install_dir, + ] + if not eula_agreement: + fvp_install_cmd += [ + "--nointeractive", + "--i-agree-to-the-contained-eula", + ] + + subprocess.check_call(fvp_install_cmd) + except subprocess.CalledProcessError as err: + raise Exception( + "Error occurred during Corstone-300 installation" + ) from err + + return dist_dir / install_dir + + +def get_corstone_300_installation() -> Installation: + """Get Corstone-300 installation.""" + corstone_300 = AIETBasedInstallation( + aiet_runner=aiet.get_aiet_runner(), + # pylint: disable=line-too-long + metadata=AIETMetadata( + name="Corstone-300", + description="Corstone-300 FVP", + system_config="aiet/systems/corstone-300/aiet-config.json", + apps_resources=[ + "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA", + "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA", + "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA", + ], + fvp_dir_name="corstone_300", + download_artifact=DownloadArtifact( + name="Corstone-300 FVP", + url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.16_26.tgz", + filename="FVP_Corstone_SSE-300_11.16_26.tgz", + version="11.16_26", + sha256_hash="e26139be756b5003a30d978c629de638aed1934d597dc24a17043d4708e934d7", + ), + supported_platforms=["Linux"], + ), + # pylint: enable=line-too-long + path_checker=CompoundPathChecker( + PackagePathChecker( + expected_files=[ + "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U55", + "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U65", + ], + backend_subfolder="models/Linux64_GCC-6.4", + ), + StaticPathChecker( + static_backend_path=Path("/opt/VHT"), + expected_files=[ + "VHT_Corstone_SSE-300_Ethos-U55", + "VHT_Corstone_SSE-300_Ethos-U65", + ], + copy_source=False, + system_config="aiet/systems/corstone-300-vht/aiet-config.json", + ), + ), + backend_installer=Corstone300Installer(), + ) + + return corstone_300 + + +def get_corstone_310_installation() -> Installation: + """Get Corstone-310 installation.""" + corstone_310 = AIETBasedInstallation( + aiet_runner=aiet.get_aiet_runner(), + # pylint: disable=line-too-long + metadata=AIETMetadata( + name="Corstone-310", + description="Corstone-310 FVP", + system_config="aiet/systems/corstone-310/aiet-config.json", + apps_resources=[ + "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA", + "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA", + ], + fvp_dir_name="corstone_310", + download_artifact=None, + supported_platforms=["Linux"], + ), + # pylint: enable=line-too-long + path_checker=CompoundPathChecker( + PackagePathChecker( + expected_files=[ + "models/Linux64_GCC-9.3/FVP_Corstone_SSE-310", + ], + backend_subfolder="models/Linux64_GCC-9.3", + ), + StaticPathChecker( + static_backend_path=Path("/opt/VHT"), + expected_files=[ + "VHT_Corstone_SSE-310", + ], + copy_source=False, + system_config="aiet/systems/corstone-310-vht/aiet-config.json", + ), + ), + backend_installer=None, + ) + + return corstone_310 + + +def get_corstone_installations() -> List[Installation]: + """Get Corstone installations.""" + return [ + get_corstone_300_installation(), + get_corstone_310_installation(), + ] diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py new file mode 100644 index 0000000..7225797 --- /dev/null +++ b/src/mlia/tools/vela_wrapper.py @@ -0,0 +1,500 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Vela wrapper module.""" +import itertools +import logging +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +from ethosu.vela.architecture_features import ArchitectureFeatures +from ethosu.vela.compiler_driver import compiler_driver +from ethosu.vela.compiler_driver import CompilerOptions +from ethosu.vela.compiler_driver import TensorAllocator +from ethosu.vela.model_reader import ModelReaderOptions +from ethosu.vela.model_reader import read_model +from ethosu.vela.nn_graph import Graph +from ethosu.vela.nn_graph import NetworkType +from ethosu.vela.npu_performance import PassCycles +from ethosu.vela.operation import CustomType +from ethosu.vela.operation import Op +from ethosu.vela.scheduler import OptimizationStrategy +from ethosu.vela.scheduler import SchedulerOptions +from ethosu.vela.tensor import BandwidthDirection +from ethosu.vela.tensor import MemArea +from ethosu.vela.tensor import Tensor +from ethosu.vela.tflite_mapping import optype_to_builtintype +from ethosu.vela.tflite_model_semantic import TFLiteSemantic +from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators +from ethosu.vela.tflite_writer import write_tflite +from ethosu.vela.vela import generate_supported_ops + +from mlia.utils.logging import redirect_output + + +logger = logging.getLogger(__name__) + +VELA_INTERNAL_OPS = (Op.Placeholder, Op.SubgraphInput, Op.Const) + + +@dataclass +class PerformanceMetrics: # pylint: disable=too-many-instance-attributes + """Contains all the performance metrics Vela generates in a run.""" + + npu_cycles: int + sram_access_cycles: int + dram_access_cycles: int + on_chip_flash_access_cycles: int + off_chip_flash_access_cycles: int + total_cycles: int + batch_inference_time: float + inferences_per_second: float + batch_size: int + unknown_memory_area_size: int + sram_memory_area_size: int + dram_memory_area_size: int + on_chip_flash_memory_area_size: int + off_chip_flash_memory_area_size: int + + +@dataclass +class NpuSupported: + """Operator's npu supported attribute.""" + + supported: bool + reasons: List[Tuple[str, str]] + + +@dataclass +class Operator: + """Model operator.""" + + name: str + op_type: str + run_on_npu: NpuSupported + + @property + def cpu_only(self) -> bool: + """Return true if operator is CPU only.""" + cpu_only_reasons = [("CPU only operator", "")] + return ( + not self.run_on_npu.supported + and self.run_on_npu.reasons == cpu_only_reasons + ) + + +@dataclass +class Operators: + """Model's operators.""" + + ops: List[Operator] + + @property + def npu_supported_ratio(self) -> float: + """Return NPU supported ratio.""" + total = self.total_number + npu_supported = self.npu_supported_number + + if total == 0 or npu_supported == 0: + return 0 + + return npu_supported / total + + @property + def npu_unsupported_ratio(self) -> float: + """Return NPU unsupported ratio.""" + return 1 - self.npu_supported_ratio + + @property + def total_number(self) -> int: + """Return total number of operators.""" + return len(self.ops) + + @property + def npu_supported_number(self) -> int: + """Return number of npu supported operators.""" + return sum(op.run_on_npu.supported for op in self.ops) + + +@dataclass +class Model: + """Model metadata.""" + + nng: Graph + network_type: NetworkType + + @property + def optimized(self) -> bool: + """Return true if model is already optimized.""" + return any( + op.attrs.get("custom_type") == CustomType.ExistingNpuOp + for sg in self.nng.subgraphs + for op in sg.get_all_ops() + ) + + +@dataclass +class OptimizedModel: + """Instance of the Vela optimized model.""" + + nng: Graph + arch: ArchitectureFeatures + compiler_options: CompilerOptions + scheduler_options: SchedulerOptions + + def save(self, output_filename: Union[str, Path]) -> None: + """Save instance of the optimized model to the file.""" + write_tflite(self.nng, output_filename) + + +AcceleratorConfigType = Literal[ + "ethos-u55-32", + "ethos-u55-64", + "ethos-u55-128", + "ethos-u55-256", + "ethos-u65-256", + "ethos-u65-512", +] + +TensorAllocatorType = Literal["LinearAlloc", "Greedy", "HillClimb"] + +OptimizationStrategyType = Literal["Performance", "Size"] + + +@dataclass +class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes + """Vela compiler options.""" + + config_files: Optional[Union[str, List[str]]] = None + system_config: str = ArchitectureFeatures.DEFAULT_CONFIG + memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG + accelerator_config: Optional[AcceleratorConfigType] = None + max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP + arena_cache_size: Optional[int] = None + tensor_allocator: TensorAllocatorType = "HillClimb" + cpu_tensor_alignment: int = Tensor.AllocationQuantum + optimization_strategy: OptimizationStrategyType = "Performance" + output_dir: Optional[str] = None + recursion_limit: int = 1000 + + +class VelaCompiler: # pylint: disable=too-many-instance-attributes + """Vela compiler wrapper.""" + + def __init__(self, compiler_options: VelaCompilerOptions): + """Init Vela wrapper instance.""" + self.config_files = compiler_options.config_files + self.system_config = compiler_options.system_config + self.memory_mode = compiler_options.memory_mode + self.accelerator_config = compiler_options.accelerator_config + self.max_block_dependency = compiler_options.max_block_dependency + self.arena_cache_size = compiler_options.arena_cache_size + self.tensor_allocator = TensorAllocator[compiler_options.tensor_allocator] + self.cpu_tensor_alignment = compiler_options.cpu_tensor_alignment + self.optimization_strategy = OptimizationStrategy[ + compiler_options.optimization_strategy + ] + self.output_dir = compiler_options.output_dir + self.recursion_limit = compiler_options.recursion_limit + + sys.setrecursionlimit(self.recursion_limit) + + def read_model(self, model: Union[str, Path]) -> Model: + """Read model.""" + logger.debug("Read model %s", model) + + nng, network_type = self._read_model(model) + return Model(nng, network_type) + + def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel: + """Compile the model.""" + if isinstance(model, (str, Path)): + nng, network_type = self._read_model(model) + else: + nng, network_type = model.nng, NetworkType.TFLite + + if not nng: + raise Exception("Unable to read model") + + try: + arch = self._architecture_features() + compiler_options = self._compiler_options() + scheduler_options = self._scheduler_options() + + with redirect_output( + logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG + ): + compiler_driver( + nng, arch, compiler_options, scheduler_options, network_type + ) + + return OptimizedModel(nng, arch, compiler_options, scheduler_options) + except (SystemExit, Exception) as err: + raise Exception("Model could not be optimized with Vela compiler") from err + + def get_config(self) -> Dict[str, Any]: + """Get compiler configuration.""" + arch = self._architecture_features() + + memory_area = { + mem.name: { + "clock_scales": arch.memory_clock_scales[mem], + "burst_length": arch.memory_burst_length[mem], + "read_latency": arch.memory_latency[mem][BandwidthDirection.Read], + "write_latency": arch.memory_latency[mem][BandwidthDirection.Write], + } + for mem in ( + MemArea.Sram, + MemArea.Dram, + MemArea.OnChipFlash, + MemArea.OffChipFlash, + ) + } + + return { + "accelerator_config": arch.accelerator_config.value, + "system_config": arch.system_config, + "core_clock": arch.core_clock, + "axi0_port": arch.axi0_port.name, + "axi1_port": arch.axi1_port.name, + "memory_mode": arch.memory_mode, + "const_mem_area": arch.const_mem_area.name, + "arena_mem_area": arch.arena_mem_area.name, + "cache_mem_area": arch.cache_mem_area.name, + "arena_cache_size": arch.arena_cache_size, + "permanent_storage_mem_area": arch.permanent_storage_mem_area.name, + "feature_map_storage_mem_area": arch.feature_map_storage_mem_area.name, + "fast_storage_mem_area": arch.fast_storage_mem_area.name, + "memory_area": memory_area, + } + + @staticmethod + def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]: + """Read TFLite model.""" + try: + model_path = str(model) if isinstance(model, Path) else model + + with redirect_output( + logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG + ): + return read_model(model_path, ModelReaderOptions()) # type: ignore + except (SystemExit, Exception) as err: + raise Exception(f"Unable to read model {model_path}") from err + + def _architecture_features(self) -> ArchitectureFeatures: + """Return ArchitectureFeatures instance.""" + return ArchitectureFeatures( + vela_config_files=self.config_files, + accelerator_config=self.accelerator_config, + system_config=self.system_config, + memory_mode=self.memory_mode, + max_blockdep=self.max_block_dependency, + verbose_config=False, + arena_cache_size=self.arena_cache_size, + ) + + def _scheduler_options(self) -> SchedulerOptions: + """Return SchedulerOptions instance.""" + arch = self._architecture_features() + + return SchedulerOptions( + optimization_strategy=self.optimization_strategy, + sram_target=arch.arena_cache_size, + verbose_schedule=False, + ) + + def _compiler_options(self) -> CompilerOptions: + """Return CompilerOptions instance.""" + return CompilerOptions( + verbose_graph=False, + verbose_quantization=False, + verbose_packing=False, + verbose_tensor_purpose=False, + verbose_tensor_format=False, + verbose_allocation=False, + verbose_high_level_command_stream=False, + verbose_register_command_stream=False, + verbose_operators=False, + verbose_weights=False, + show_cpu_operations=False, + tensor_allocator=self.tensor_allocator, + timing=False, + output_dir=self.output_dir, + cpu_tensor_alignment=self.cpu_tensor_alignment, + ) + + +def resolve_compiler_config( + vela_compiler_options: VelaCompilerOptions, +) -> Dict[str, Any]: + """Resolve passed compiler options. + + Vela has number of configuration parameters that being + resolved during passing compiler options. E.g. Vela + reads configuration parameters from vela.ini and fills + it's internal structures with resolved values (memory mode, + system mode, etc.). + + In order to get this information we need to create + instance of the Vela compiler first. + """ + vela_compiler = VelaCompiler(vela_compiler_options) + return vela_compiler.get_config() + + +def estimate_performance( + model_path: Path, compiler_options: VelaCompilerOptions +) -> PerformanceMetrics: + """Return performance estimations for the model/device. + + Logic for this function comes from Vela module stats_writer.py + """ + logger.debug( + "Estimate performance for the model %s on %s", + model_path, + compiler_options.accelerator_config, + ) + + vela_compiler = VelaCompiler(compiler_options) + + initial_model = vela_compiler.read_model(model_path) + if initial_model.optimized: + raise Exception("Unable to estimate performance for the given optimized model") + + optimized_model = vela_compiler.compile_model(initial_model) + + return _performance_metrics(optimized_model) + + +def optimize_model( + model_path: Path, compiler_options: VelaCompilerOptions, output_model_path: Path +) -> None: + """Optimize model and return it's path after optimization.""" + logger.debug( + "Optimize model %s for device %s", + model_path, + compiler_options.accelerator_config, + ) + + vela_compiler = VelaCompiler(compiler_options) + optimized_model = vela_compiler.compile_model(model_path) + + logger.debug("Save optimized model into %s", output_model_path) + optimized_model.save(output_model_path) + + +def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics: + """Return performance metrics for optimized model.""" + cycles = optimized_model.nng.cycles + + def memory_usage(mem_area: MemArea) -> int: + """Get memory usage for the proviced memory area type.""" + memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used + bandwidths = optimized_model.nng.bandwidths + + return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0 + + midpoint_fps = np.nan + midpoint_inference_time = cycles[PassCycles.Total] / optimized_model.arch.core_clock + if midpoint_inference_time > 0: + midpoint_fps = 1 / midpoint_inference_time + + return PerformanceMetrics( + npu_cycles=int(cycles[PassCycles.Npu]), + sram_access_cycles=int(cycles[PassCycles.SramAccess]), + dram_access_cycles=int(cycles[PassCycles.DramAccess]), + on_chip_flash_access_cycles=int(cycles[PassCycles.OnChipFlashAccess]), + off_chip_flash_access_cycles=int(cycles[PassCycles.OffChipFlashAccess]), + total_cycles=int(cycles[PassCycles.Total]), + batch_inference_time=midpoint_inference_time * 1000, + inferences_per_second=midpoint_fps, + batch_size=optimized_model.nng.batch_size, + unknown_memory_area_size=memory_usage(MemArea.Unknown), + sram_memory_area_size=memory_usage(MemArea.Sram), + dram_memory_area_size=memory_usage(MemArea.Dram), + on_chip_flash_memory_area_size=memory_usage(MemArea.OnChipFlash), + off_chip_flash_memory_area_size=memory_usage(MemArea.OffChipFlash), + ) + + +def supported_operators( + model_path: Path, compiler_options: VelaCompilerOptions +) -> Operators: + """Return list of model's operators.""" + logger.debug("Check supported operators for the model %s", model_path) + + vela_compiler = VelaCompiler(compiler_options) + initial_model = vela_compiler.read_model(model_path) + + return Operators( + [ + Operator(op.name, optype_to_builtintype(op.type), run_on_npu(op)) + for sg in initial_model.nng.subgraphs + for op in sg.get_all_ops() + if op.type not in VELA_INTERNAL_OPS + ] + ) + + +def run_on_npu(operator: Op) -> NpuSupported: + """Return information if operator can run on NPU. + + Vela does a number of checks that can help establish whether + a particular operator is supported to run on NPU. + + There are two groups of checks: + - general TFLite constraints + - operator specific constraints + + If an operator is not supported on NPU then this function + will return the reason of that. + + The reason is split in two parts: + - general description of why the operator cannot be placed on NPU + - details on the particular operator + """ + semantic_checker = TFLiteSemantic() + semantic_constraints = itertools.chain( + semantic_checker.generic_constraints, + semantic_checker.specific_constraints[operator.type], + ) + + for constraint in semantic_constraints: + op_valid, op_reason = constraint(operator) + if not op_valid: + return NpuSupported(False, [(constraint.__doc__, op_reason)]) + + if operator.type not in TFLiteSupportedOperators.supported_operators: + reasons = ( + [("CPU only operator", "")] + if operator.type not in VELA_INTERNAL_OPS + else [] + ) + + return NpuSupported(False, reasons) + + tflite_supported_operators = TFLiteSupportedOperators() + operation_constraints = itertools.chain( + tflite_supported_operators.generic_constraints, + tflite_supported_operators.specific_constraints[operator.type], + ) + for constraint in operation_constraints: + op_valid, op_reason = constraint(operator) + if not op_valid: + return NpuSupported(False, [(constraint.__doc__, op_reason)]) + + return NpuSupported(True, []) + + +def generate_supported_operators_report() -> None: + """Generate supported operators report in current working directory.""" + with redirect_output(logger): + generate_supported_ops() diff --git a/src/mlia/utils/__init__.py b/src/mlia/utils/__init__.py new file mode 100644 index 0000000..ecb5ca1 --- /dev/null +++ b/src/mlia/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils module.""" diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py new file mode 100644 index 0000000..7cb3d83 --- /dev/null +++ b/src/mlia/utils/console.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Console output utility functions.""" +from typing import Iterable +from typing import List +from typing import Optional + +from rich.console import Console +from rich.console import RenderableType +from rich.table import box +from rich.table import Table +from rich.text import Text + + +def create_section_header( + section_name: Optional[str] = None, length: int = 80, sep: str = "-" +) -> str: + """Return section header.""" + if not section_name: + content = sep * length + else: + before = 3 + spaces = 2 + after = length - (len(section_name) + before + spaces) + if after < 0: + raise ValueError("Section name too long") + content = f"{sep * before} {section_name} {sep * after}" + + return f"\n{content}\n" + + +def apply_style(value: str, style: str) -> str: + """Apply style to the value.""" + return f"[{style}]{value}" + + +def style_improvement(result: bool) -> str: + """Return different text style based on result.""" + return "green" if result else "yellow" + + +def produce_table( + rows: Iterable, + headers: Optional[List[str]] = None, + table_style: str = "default", +) -> str: + """Represent data in tabular form.""" + table = _get_table(table_style) + + if headers: + table.show_header = True + for header in headers: + table.add_column(header) + + for row in rows: + table.add_row(*row) + + return _convert_to_text(table) + + +def _get_table(table_style: str) -> Table: + """Get Table instance for the provided style.""" + if table_style == "default": + return Table( + show_header=False, + show_lines=True, + box=box.SQUARE_DOUBLE_HEAD, + ) + + if table_style == "nested": + return Table( + show_header=False, + box=None, + padding=(0, 1, 1, 0), + ) + + if table_style == "no_borders": + return Table(show_header=False, box=None) + + raise Exception(f"Unsupported table style {table_style}") + + +def _convert_to_text(*renderables: RenderableType) -> str: + """Convert renderable object to text.""" + console = Console() + with console.capture() as capture: + for item in renderables: + console.print(item) + + text = capture.get() + return text.rstrip() + + +def remove_ascii_codes(value: str) -> str: + """Decode and remove ASCII codes.""" + text = Text.from_ansi(value) + return text.plain diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py new file mode 100644 index 0000000..4658738 --- /dev/null +++ b/src/mlia/utils/download.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils for files downloading.""" +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Optional + +import requests +from rich.progress import BarColumn +from rich.progress import DownloadColumn +from rich.progress import FileSizeColumn +from rich.progress import Progress +from rich.progress import ProgressColumn +from rich.progress import TextColumn + +from mlia.utils.filesystem import sha256 +from mlia.utils.types import parse_int + + +def download_progress( + content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str] +) -> Iterable[bytes]: + """Show progress info while reading content.""" + columns: List[ProgressColumn] = [TextColumn("{task.description}")] + + if content_length is None: + total = float("inf") + columns.append(FileSizeColumn()) + else: + total = content_length + columns.extend([BarColumn(), DownloadColumn(binary_units=True)]) + + with Progress(*columns) as progress: + task = progress.add_task(label or "Downloading", total=total) + + for chunk in content_chunks: + progress.update(task, advance=len(chunk)) + yield chunk + + +def download( + url: str, + dest: Path, + show_progress: bool = False, + label: Optional[str] = None, + chunk_size: int = 8192, +) -> None: + """Download the file.""" + with requests.get(url, stream=True) as resp: + resp.raise_for_status() + content_chunks = resp.iter_content(chunk_size=chunk_size) + + if show_progress: + content_length = parse_int(resp.headers.get("Content-Length")) + content_chunks = download_progress(content_chunks, content_length, label) + + with open(dest, "wb") as file: + for chunk in content_chunks: + file.write(chunk) + + +@dataclass +class DownloadArtifact: + """Download artifact attributes.""" + + name: str + url: str + filename: str + version: str + sha256_hash: str + + def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path: + """Download artifact into destination directory.""" + if (dest := dest_dir / self.filename).exists(): + raise ValueError(f"{dest} already exists") + + download( + self.url, + dest, + show_progress=show_progress, + label=f"Downloading {self.name} ver. {self.version}", + ) + + if sha256(dest) != self.sha256_hash: + raise ValueError("Digests do not match") + + return dest diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py new file mode 100644 index 0000000..73a88d9 --- /dev/null +++ b/src/mlia/utils/filesystem.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils related to file management.""" +import hashlib +import importlib.resources as pkg_resources +import json +import os +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import mkstemp +from tempfile import TemporaryDirectory +from typing import Any +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + + +def get_mlia_resources() -> Path: + """Get the path to the resources directory.""" + with pkg_resources.path("mlia", "__init__.py") as init_path: + project_root = init_path.parent + return project_root / "resources" + + +def get_vela_config() -> Path: + """Get the path to the default Vela config file.""" + return get_mlia_resources() / "vela/vela.ini" + + +def get_profiles_file() -> Path: + """Get the Ethos-U profiles file.""" + return get_mlia_resources() / "profiles.json" + + +def get_profiles_data() -> Dict[str, Dict[str, Any]]: + """Get the Ethos-U profile values as a dictionary.""" + with open(get_profiles_file(), encoding="utf-8") as json_file: + profiles = json.load(json_file) + + if not isinstance(profiles, dict): + raise Exception("Profiles data format is not valid") + + return profiles + + +def get_profile(target: str) -> Dict[str, Any]: + """Get settings for the provided target profile.""" + profiles = get_profiles_data() + + if target not in profiles: + raise Exception(f"Unable to find target profile {target}") + + return profiles[target] + + +def get_supported_profile_names() -> List[str]: + """Get the supported Ethos-U profile names.""" + return list(get_profiles_data().keys()) + + +@contextmanager +def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: + """Create temp file and remove it after.""" + _, tmp_file = mkstemp(suffix=suffix) + + try: + yield Path(tmp_file) + finally: + os.remove(tmp_file) + + +@contextmanager +def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]: + """Create temp directory and remove it after.""" + with TemporaryDirectory(suffix=suffix) as tmpdir: + yield Path(tmpdir) + + +def file_chunks( + filepath: Union[Path, str], chunk_size: int = 4096 +) -> Generator[bytes, None, None]: + """Return sequence of the file chunks.""" + with open(filepath, "rb") as file: + while data := file.read(chunk_size): + yield data + + +def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str: + """Return hex digest of the file.""" + for chunk in file_chunks(filepath): + hash_obj.update(chunk) + + return hash_obj.hexdigest() + + +def sha256(filepath: Path) -> str: + """Return SHA256 hash of the file.""" + return hexdigest(filepath, hashlib.sha256()) + + +def all_files_exist(paths: Iterable[Path]) -> bool: + """Check if all files are exist.""" + return all(item.is_file() for item in paths) + + +def all_paths_valid(paths: Iterable[Path]) -> bool: + """Check if all paths are valid.""" + return all(item.exists() for item in paths) + + +def copy_all(*paths: Path, dest: Path) -> None: + """Copy files/directories into destination folder.""" + dest.mkdir(exist_ok=True) + + for path in paths: + if path.is_file(): + shutil.copy2(path, dest) + + if path.is_dir(): + shutil.copytree(path, dest, dirs_exist_ok=True) diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py new file mode 100644 index 0000000..86d7567 --- /dev/null +++ b/src/mlia/utils/logging.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Logging utility functions.""" +import logging +from contextlib import contextmanager +from contextlib import ExitStack +from contextlib import redirect_stderr +from contextlib import redirect_stdout +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Generator +from typing import List +from typing import Optional + + +class LoggerWriter: + """Redirect printed messages to the logger.""" + + def __init__(self, logger: logging.Logger, level: int): + """Init logger writer.""" + self.logger = logger + self.level = level + + def write(self, message: str) -> None: + """Write message.""" + if message.strip() != "": + self.logger.log(self.level, message) + + def flush(self) -> None: + """Flush buffers.""" + + +@contextmanager +def redirect_output( + logger: logging.Logger, + stdout_level: int = logging.INFO, + stderr_level: int = logging.INFO, +) -> Generator[None, None, None]: + """Redirect standard output to the logger.""" + stdout_to_log = LoggerWriter(logger, stdout_level) + stderr_to_log = LoggerWriter(logger, stderr_level) + + with ExitStack() as exit_stack: + exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore + exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore + + yield + + +class LogFilter(logging.Filter): + """Configurable log filter.""" + + def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None: + """Init log filter instance.""" + super().__init__() + self.log_record_filter = log_record_filter + + def filter(self, record: logging.LogRecord) -> bool: + """Filter log messages.""" + return self.log_record_filter(record) + + @classmethod + def equals(cls, log_level: int) -> "LogFilter": + """Return log filter that filters messages by log level.""" + + def filter_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno == log_level + + return cls(filter_by_level) + + @classmethod + def skip(cls, log_level: int) -> "LogFilter": + """Return log filter that skips messages with particular level.""" + + def skip_by_level(log_record: logging.LogRecord) -> bool: + return log_record.levelno != log_level + + return cls(skip_by_level) + + +def create_log_handler( + *, + file_path: Optional[Path] = None, + stream: Optional[Any] = None, + log_level: Optional[int] = None, + log_format: Optional[str] = None, + log_filter: Optional[logging.Filter] = None, + delay: bool = True, +) -> logging.Handler: + """Create logger handler.""" + handler: Optional[logging.Handler] = None + + if file_path is not None: + handler = logging.FileHandler(file_path, delay=delay) + elif stream is not None: + handler = logging.StreamHandler(stream) + + if handler is None: + raise Exception("Unable to create logging handler") + + if log_level: + handler.setLevel(log_level) + + if log_format: + handler.setFormatter(logging.Formatter(log_format)) + + if log_filter: + handler.addFilter(log_filter) + + return handler + + +def attach_handlers( + handlers: List[logging.Handler], loggers: List[logging.Logger] +) -> None: + """Attach handlers to the loggers.""" + for handler in handlers: + for logger in loggers: + logger.addHandler(handler) diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py new file mode 100644 index 0000000..de95448 --- /dev/null +++ b/src/mlia/utils/misc.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Various util functions.""" + + +def yes(prompt: str) -> bool: + """Return true if user confirms the action.""" + response = input(f"{prompt} [y/n]: ") + return response in ["y", "Y"] diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py new file mode 100644 index 0000000..39aca43 --- /dev/null +++ b/src/mlia/utils/proc.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils related to process management.""" +import os +import signal +import subprocess +import time +from abc import ABC +from abc import abstractmethod +from contextlib import contextmanager +from contextlib import suppress +from pathlib import Path +from typing import Any +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + + +class OutputConsumer(ABC): + """Base class for the output consumers.""" + + @abstractmethod + def feed(self, line: str) -> None: + """Feed new line to the consumerr.""" + + +class RunningCommand: + """Running command.""" + + def __init__(self, process: subprocess.Popen) -> None: + """Init running command instance.""" + self.process = process + self._output_consumers: Optional[List[OutputConsumer]] = None + + def is_alive(self) -> bool: + """Return true if process is still alive.""" + return self.process.poll() is None + + def exit_code(self) -> Optional[int]: + """Return process's return code.""" + return self.process.poll() + + def stdout(self) -> Iterable[str]: + """Return std output of the process.""" + assert self.process.stdout is not None + + for line in self.process.stdout: + yield line + + def kill(self) -> None: + """Kill the process.""" + self.process.kill() + + def send_signal(self, signal_num: int) -> None: + """Send signal to the process.""" + self.process.send_signal(signal_num) + + @property + def output_consumers(self) -> Optional[List[OutputConsumer]]: + """Property output_consumers.""" + return self._output_consumers + + @output_consumers.setter + def output_consumers(self, output_consumers: List[OutputConsumer]) -> None: + """Set output consumers.""" + self._output_consumers = output_consumers + + def consume_output(self) -> None: + """Pass program's output to the consumers.""" + if self.process is None or self.output_consumers is None: + return + + for line in self.stdout(): + for consumer in self.output_consumers: + with suppress(): + consumer.feed(line) + + def stop( + self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5 + ) -> None: + """Stop execution.""" + try: + if not self.is_alive(): + return + + self.process.send_signal(signal.SIGINT) + self.consume_output() + + if not wait: + return + + for _ in range(num_of_attempts): + time.sleep(interval) + if not self.is_alive(): + break + else: + raise Exception("Unable to stop running command") + finally: + self._close_fd() + + def _close_fd(self) -> None: + """Close file descriptors.""" + + def close(file_descriptor: Any) -> None: + """Check and close file.""" + if file_descriptor is not None and hasattr(file_descriptor, "close"): + file_descriptor.close() + + close(self.process.stdout) + close(self.process.stderr) + + def wait(self, redirect_output: bool = False) -> None: + """Redirect process output to stdout and wait for completion.""" + if redirect_output: + for line in self.stdout(): + print(line, end="") + + self.process.wait() + + +class CommandExecutor: + """Command executor.""" + + @staticmethod + def execute(command: List[str]) -> Tuple[int, bytes, bytes]: + """Execute the command.""" + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + + return (result.returncode, result.stdout, result.stderr) + + @staticmethod + def submit(command: List[str]) -> RunningCommand: + """Submit command for the execution.""" + process = subprocess.Popen( # pylint: disable=consider-using-with + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # redirect command stderr to stdout + universal_newlines=True, + bufsize=1, + ) + + return RunningCommand(process) + + +@contextmanager +def working_directory( + working_dir: Path, create_dir: bool = False +) -> Generator[Path, None, None]: + """Temporary change working directory.""" + current_working_dir = Path.cwd() + + if create_dir: + working_dir.mkdir() + + os.chdir(working_dir) + + try: + yield working_dir + finally: + os.chdir(current_working_dir) diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py new file mode 100644 index 0000000..9b63928 --- /dev/null +++ b/src/mlia/utils/types.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Types related utility functions.""" +from typing import Any +from typing import Optional + + +def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool: + """Check if data is a list of object of the same class.""" + return ( + isinstance(data, (tuple, list)) + and all(isinstance(item, cls) for item in data) + and (elem_num is None or len(data) == elem_num) + ) + + +def is_number(value: str) -> bool: + """Return true if string contains a number.""" + try: + float(value) + except ValueError: + return False + + return True + + +def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]: + """Parse integer value.""" + try: + return int(value) + except (TypeError, ValueError): + return default + + +def only_one_selected(*options: bool) -> bool: + """Return true if only one True value found.""" + return sum(options) == 1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..4a1e153 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests module.""" diff --git a/tests/aiet/__init__.py b/tests/aiet/__init__.py new file mode 100644 index 0000000..873a7df --- /dev/null +++ b/tests/aiet/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""AIET tests module.""" diff --git a/tests/aiet/conftest.py b/tests/aiet/conftest.py new file mode 100644 index 0000000..cab3dc2 --- /dev/null +++ b/tests/aiet/conftest.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name +"""conftest for pytest.""" +import shutil +import tarfile +from pathlib import Path +from typing import Any + +import pytest +from click.testing import CliRunner + +from aiet.backend.common import get_backend_configs + + +@pytest.fixture(scope="session") +def test_systems_path(test_resources_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_resources_path / "systems" + + +@pytest.fixture(scope="session") +def test_applications_path(test_resources_path: Path) -> Path: + """Return test applications path in a pytest fixture.""" + return test_resources_path / "applications" + + +@pytest.fixture(scope="session") +def test_tools_path(test_resources_path: Path) -> Path: + """Return test tools path in a pytest fixture.""" + return test_resources_path / "tools" + + +@pytest.fixture(scope="session") +def test_resources_path() -> Path: + """Return test resources path in a pytest fixture.""" + current_path = Path(__file__).parent.absolute() + return current_path / "test_resources" + + +@pytest.fixture(scope="session") +def non_optimised_input_model_file(test_tflite_model: Path) -> Path: + """Provide the path to a quantized dummy model file.""" + return test_tflite_model + + +@pytest.fixture(scope="session") +def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: + """Provide path to Vela-optimised dummy model file.""" + return test_tflite_vela_model + + +@pytest.fixture(scope="session") +def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: + """Provide the path to an invalid dummy model file.""" + return test_tflite_invalid_model + + +@pytest.fixture(autouse=True) +def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: + """Force using test resources as middleware's repository.""" + + def get_test_resources() -> Path: + """Return path to the test resources.""" + return test_resources_path + + monkeypatch.setattr("aiet.utils.fs.get_aiet_resources", get_test_resources) + yield + + +@pytest.fixture(scope="session", autouse=True) +def add_tools(test_resources_path: Path) -> Any: + """Symlink the tools from the original resources path to the test resources path.""" + # tool_dirs = get_available_tool_directory_names() + tool_dirs = [cfg.parent for cfg in get_backend_configs("tools")] + + links = { + src_dir: (test_resources_path / "tools" / src_dir.name) for src_dir in tool_dirs + } + for src_dir, dst_dir in links.items(): + if not dst_dir.exists(): + dst_dir.symlink_to(src_dir, target_is_directory=True) + yield + # Remove symlinks + for dst_dir in links.values(): + if dst_dir.is_symlink(): + dst_dir.unlink() + + +def create_archive( + archive_name: str, source: Path, destination: Path, with_root_folder: bool = False +) -> None: + """Create archive from directory source.""" + with tarfile.open(destination / archive_name, mode="w:gz") as tar: + for item in source.iterdir(): + item_name = item.name + if with_root_folder: + item_name = f"{source.name}/{item_name}" + tar.add(item, item_name) + + +def process_directory(source: Path, destination: Path) -> None: + """Process resource directory.""" + destination.mkdir() + + for item in source.iterdir(): + if item.is_dir(): + create_archive(f"{item.name}.tar.gz", item, destination) + create_archive(f"{item.name}_dir.tar.gz", item, destination, True) + + +@pytest.fixture(scope="session", autouse=True) +def add_archives( + test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory +) -> Any: + """Generate archives of the test resources.""" + tmp_path = tmp_path_factory.mktemp("archives") + + archives_path = tmp_path / "archives" + archives_path.mkdir() + + if (archives_path_link := test_resources_path / "archives").is_symlink(): + archives_path.unlink() + + archives_path_link.symlink_to(archives_path, target_is_directory=True) + + for item in ["applications", "systems"]: + process_directory(test_resources_path / item, archives_path / item) + + yield + + archives_path_link.unlink() + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="module") +def cli_runner() -> CliRunner: + """Return CliRunner instance in a pytest fixture.""" + return CliRunner() diff --git a/tests/aiet/test_backend_application.py b/tests/aiet/test_backend_application.py new file mode 100644 index 0000000..abfab00 --- /dev/null +++ b/tests/aiet/test_backend_application.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the application backend.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import List +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.application import Application +from aiet.backend.application import get_application +from aiet.backend.application import get_available_application_directory_names +from aiet.backend.application import get_available_applications +from aiet.backend.application import get_unique_application_names +from aiet.backend.application import install_application +from aiet.backend.application import load_applications +from aiet.backend.application import remove_application +from aiet.backend.common import Command +from aiet.backend.common import DataPaths +from aiet.backend.common import Param +from aiet.backend.common import UserParamConfig +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import ExtendedApplicationConfig +from aiet.backend.config import NamedExecutionConfig + + +def test_get_available_application_directory_names() -> None: + """Test get_available_applicationss mocking get_resources.""" + directory_names = get_available_application_directory_names() + assert Counter(directory_names) == Counter( + ["application1", "application2", "application4", "application5"] + ) + + +def test_get_available_applications() -> None: + """Test get_available_applicationss mocking get_resources.""" + available_applications = get_available_applications() + + assert all(isinstance(s, Application) for s in available_applications) + assert all(s != 42 for s in available_applications) + assert len(available_applications) == 9 + # application_5 has multiply items with multiply supported systems + assert [str(s) for s in available_applications] == [ + "application_1", + "application_2", + "application_4", + "application_5", + "application_5", + "application_5A", + "application_5A", + "application_5B", + "application_5B", + ] + + +def test_get_unique_application_names() -> None: + """Test get_unique_application_names.""" + unique_names = get_unique_application_names() + + assert all(isinstance(s, str) for s in unique_names) + assert all(s for s in unique_names) + assert sorted(unique_names) == [ + "application_1", + "application_2", + "application_4", + "application_5", + "application_5A", + "application_5B", + ] + + +def test_get_application() -> None: + """Test get_application mocking get_resoures.""" + application = get_application("application_1") + if len(application) != 1: + pytest.fail("Unable to get application") + assert application[0].name == "application_1" + + application = get_application("unknown application") + assert len(application) == 0 + + +@pytest.mark.parametrize( + "source, call_count, expected_exception", + ( + ( + "archives/applications/application1.tar.gz", + 0, + pytest.raises( + Exception, match=r"Applications \[application_1\] are already installed" + ), + ), + ( + "various/applications/application_with_empty_config", + 0, + pytest.raises(Exception, match="No application definition found"), + ), + ( + "various/applications/application_with_wrong_config1", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config2", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ("various/applications/application_with_valid_config", 1, does_not_raise()), + ( + "archives/applications/application3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "applications/application1", + 0, + pytest.raises( + Exception, match=r"Applications \[application_1\] are already installed" + ), + ), + ( + "applications/application3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ), +) +def test_install_application( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + expected_exception: Any, +) -> None: + """Test application install from archive.""" + mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "aiet.backend.application.create_destination_and_install", + mock_create_destination_and_install, + ) + + with expected_exception: + install_application(test_resources_path / source) + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_application(monkeypatch: Any) -> None: + """Test application removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("aiet.backend.application.remove_backend", mock_remove_backend) + + remove_application("some_application_directory") + mock_remove_backend.assert_called_once() + + +def test_application_config_without_commands() -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + +class TestApplication: + """Test for application class methods.""" + + def test___eq__(self) -> None: + """Test overloaded __eq__ method.""" + config = ApplicationConfig( + # Application + supported_systems=["system1", "system2"], + build_dir="build_dir", + # inherited from Backend + name="name", + description="description", + commands={}, + ) + application1 = Application(config) + application2 = Application(config) # Identical + assert application1 == application2 + + application3 = Application(config) # changed + # Change one single attribute so not equal, but same Type + setattr(application3, "supported_systems", ["somewhere/else"]) + assert application1 != application3 + + # different Type + application4 = "Not the Application you are looking for" + assert application1 != application4 + + application5 = Application(config) + # supported systems could be in any order + setattr(application5, "supported_systems", ["system2", "system1"]) + assert application1 == application5 + + def test_can_run_on(self) -> None: + """Test Application can run on.""" + config = ApplicationConfig(name="application", supported_systems=["System-A"]) + + application = Application(config) + assert application.can_run_on("System-A") + assert not application.can_run_on("System-B") + + applications = get_application("application_1", "System 1") + assert len(applications) == 1 + assert applications[0].can_run_on("System 1") + + def test_get_deploy_data(self, tmp_path: Path) -> None: + """Test Application can run on.""" + src, dest = "src", "dest" + config = ApplicationConfig( + name="application", deploy_data=[(src, dest)], config_location=tmp_path + ) + src_path = tmp_path / src + src_path.mkdir() + application = Application(config) + assert application.get_deploy_data() == [DataPaths(src_path, dest)] + + def test_get_deploy_data_no_config_location(self) -> None: + """Test that getting deploy data fails if no config location provided.""" + with pytest.raises( + Exception, match="Unable to get application .* config location" + ): + Application(ApplicationConfig(name="application")).get_deploy_data() + + def test_unable_to_create_application_without_name(self) -> None: + """Test that it is not possible to create application without name.""" + with pytest.raises(Exception, match="Name is empty"): + Application(ApplicationConfig()) + + def test_application_config_without_commands(self) -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + @pytest.mark.parametrize( + "config, expected_params", + ( + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:0} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1")], + ), + ), + ) + def test_remove_unused_params( + self, config: ApplicationConfig, expected_params: List[Param] + ) -> None: + """Test mod remove_unused_parameter.""" + application = Application(config) + application.remove_unused_params() + assert application.commands["command"].params == expected_params + + +@pytest.mark.parametrize( + "config, expected_error", + ( + ( + ExtendedApplicationConfig(name="application"), + pytest.raises(Exception, match="No supported systems definition provided"), + ), + ( + ExtendedApplicationConfig( + name="application", supported_systems=[NamedExecutionConfig(name="")] + ), + pytest.raises( + Exception, + match="Unable to read supported system definition, name is missed", + ), + ), + ( + ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": ["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ), + pytest.raises( + Exception, match="Default parameters for command .* should have aliases" + ), + ), + ( + ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": ["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param", alias="param")]}, + ), + pytest.raises( + Exception, match="system parameters for command .* should have aliases" + ), + ), + ), +) +def test_load_application_exceptional_cases( + config: ExtendedApplicationConfig, expected_error: Any +) -> None: + """Test exceptional cases for application load function.""" + with expected_error: + load_applications(config) + + +def test_load_application() -> None: + """Test application load function. + + The main purpose of this test is to test configuration for application + for different systems. All configuration should be correctly + overridden if needed. + """ + application_5 = get_application("application_5") + assert len(application_5) == 2 + + default_commands = { + "build": Command(["default build command"]), + "run": Command(["default run command"]), + } + default_variables = {"var1": "value1", "var2": "value2"} + + application_5_0 = application_5[0] + assert application_5_0.build_dir == "default_build_dir" + assert application_5_0.supported_systems == ["System 1"] + assert application_5_0.commands == default_commands + assert application_5_0.variables == default_variables + assert application_5_0.lock is False + + application_5_1 = application_5[1] + assert application_5_1.build_dir == application_5_0.build_dir + assert application_5_1.supported_systems == ["System 2"] + assert application_5_1.commands == application_5_1.commands + assert application_5_1.variables == default_variables + + application_5a = get_application("application_5A") + assert len(application_5a) == 2 + + application_5a_0 = application_5a[0] + assert application_5a_0.supported_systems == ["System 1"] + assert application_5a_0.build_dir == "build_5A" + assert application_5a_0.commands == default_commands + assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"} + assert application_5a_0.lock is False + + application_5a_1 = application_5a[1] + assert application_5a_1.supported_systems == ["System 2"] + assert application_5a_1.build_dir == "build" + assert application_5a_1.commands == { + "build": Command(["default build command"]), + "run": Command(["run command on system 2"]), + } + assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"} + assert application_5a_1.lock is True + + application_5b = get_application("application_5B") + assert len(application_5b) == 2 + + application_5b_0 = application_5b[0] + assert application_5b_0.build_dir == "build_5B" + assert application_5b_0.supported_systems == ["System 1"] + assert application_5b_0.commands == { + "build": Command(["default build command with value for var1 System1"], []), + "run": Command(["default run command with value for var2 System1"]), + } + assert "non_used_command" not in application_5b_0.commands + + application_5b_1 = application_5b[1] + assert application_5b_1.build_dir == "build" + assert application_5b_1.supported_systems == ["System 2"] + assert application_5b_1.commands == { + "build": Command( + [ + "build command on system 2 with value" + " for var1 System2 {user_params:param1}" + ], + [ + Param( + "--param", + "Sample command param", + ["value1", "value2", "value3"], + "value1", + ) + ], + ), + "run": Command(["run command on system 2"], []), + } diff --git a/tests/aiet/test_backend_common.py b/tests/aiet/test_backend_common.py new file mode 100644 index 0000000..12c30ec --- /dev/null +++ b/tests/aiet/test_backend_common.py @@ -0,0 +1,486 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,protected-access +"""Tests for the common backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import IO +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.application import Application +from aiet.backend.common import Backend +from aiet.backend.common import BaseBackendConfig +from aiet.backend.common import Command +from aiet.backend.common import ConfigurationException +from aiet.backend.common import load_config +from aiet.backend.common import Param +from aiet.backend.common import parse_raw_parameter +from aiet.backend.common import remove_backend +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import UserParamConfig +from aiet.backend.execution import ExecutionContext +from aiet.backend.execution import ParamResolver +from aiet.backend.system import System + + +@pytest.mark.parametrize( + "directory_name, expected_exception", + ( + ("some_dir", does_not_raise()), + (None, pytest.raises(Exception, match="No directory name provided")), + ), +) +def test_remove_backend( + monkeypatch: Any, directory_name: str, expected_exception: Any +) -> None: + """Test remove_backend function.""" + mock_remove_resource = MagicMock() + monkeypatch.setattr("aiet.backend.common.remove_resource", mock_remove_resource) + + with expected_exception: + remove_backend(directory_name, "applications") + + +@pytest.mark.parametrize( + "filename, expected_exception", + ( + ("application_config.json", does_not_raise()), + (None, pytest.raises(Exception, match="Unable to read config")), + ), +) +def test_load_config( + filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any +) -> None: + """Test load_config.""" + with expected_exception: + configs: List[Optional[Union[Path, IO[bytes]]]] = ( + [None] + if not filename + else [ + # Ignore pylint warning as 'with' can't be used inside of a + # generator expression. + # pylint: disable=consider-using-with + open(test_resources_path / filename, "rb"), + test_resources_path / filename, + ] + ) + for config in configs: + json_mock = MagicMock() + monkeypatch.setattr("aiet.backend.common.json.load", json_mock) + load_config(config) + json_mock.assert_called_once() + + +class TestBackend: + """Test Backend class.""" + + def test___repr__(self) -> None: + """Test the representation of Backend instance.""" + backend = Backend( + BaseBackendConfig(name="Testing name", description="Testing description") + ) + assert str(backend) == "Testing name" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + backend1 = Backend(BaseBackendConfig(name="name", description="description")) + backend1.commands = {"command": Command(["command"])} + + backend2 = Backend(BaseBackendConfig(name="name", description="description")) + backend2.commands = {"command": Command(["command"])} + + backend3 = Backend( + BaseBackendConfig( + name="Ben", description="This is not the Backend you are looking for" + ) + ) + backend3.commands = {"wave": Command(["wave hand"])} + + backend4 = "Foo" # checking not isinstance(backend4, Backend) + + assert backend1 == backend2 + assert backend1 != backend3 + assert backend1 != backend4 + + @pytest.mark.parametrize( + "parameter, valid", + [ + ("--choice-param dummy_value_1", True), + ("--choice-param wrong_value", False), + ("--open-param something", True), + ("--wrong-param value", False), + ], + ) + def test_validate_parameter( + self, parameter: str, valid: bool, test_resources_path: Path + ) -> None: + """Test validate_parameter.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + # The application configuration is a list of configurations so we need + # only the first one + # Exercise the validate_parameter test using the Application classe which + # inherits from Backend. + application = Application(config[0]) + assert application.validate_parameter("run", parameter) == valid + + def test_validate_parameter_with_invalid_command( + self, test_resources_path: Path + ) -> None: + """Test validate_parameter with an invalid command_name.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + application = Application(config[0]) + with pytest.raises(AttributeError) as err: + # command foo does not exist, so raise an error + application.validate_parameter("foo", "bar") + assert "Unknown command: 'foo'" in str(err.value) + + def test_build_command(self, monkeypatch: Any) -> None: + """Test command building.""" + config = { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + "post_run": ["post_run {application_params:0} on {system_params:0}"], + "some_command": ["Command with {variables:var_A}"], + "empty_command": [""], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": [1, 2, 3], + "default_value": 1, + }, + {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3}, + {"name": "choice_param_3", "values": [6, 7, 8]}, + ], + "run": [{"name": "flag_param_0"}], + }, + "variables": {"var_A": "value for variable A"}, + } + + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + application, system = Application(config), System(config) # type: ignore + context = ExecutionContext( + app=application, + app_params=[], + system=system, + system_params=[], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(context) + + cmd = application.build_command( + "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 4"] + + cmd = application.build_command("build", ["choice_param_0=2"], param_resolver) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + cmd = application.build_command( + "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + with pytest.raises( + ConfigurationException, match="Command 'foo' could not be found." + ): + application.build_command("foo", [""], param_resolver) + + cmd = application.build_command("some_command", [], param_resolver) + assert cmd == ["Command with value for variable A"] + + cmd = application.build_command("empty_command", [], param_resolver) + assert cmd == [""] + + @pytest.mark.parametrize("class_", [Application, System]) + def test_build_command_unknown_variable(self, class_: type) -> None: + """Test that unable to construct backend with unknown variable.""" + with pytest.raises(Exception, match="Unknown variable var1"): + config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}} + class_(config) + + @pytest.mark.parametrize( + "class_, config, expected_output", + [ + ( + Application, + { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_1", + }, + { + "name": "choice_param_1", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_2", + }, + {"name": "choice_param_3", "values": ["a", "b", "c"]}, + ], + "run": [{"name": "flag_param_0"}], + }, + }, + [ + ( + "b", + Param( + name="choice_param_0=", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_1", + ), + ), + ( + "a", + Param( + name="choice_param_1", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_2", + ), + ), + ( + "c", + Param( + name="choice_param_3", + description="", + values=["a", "b", "c"], + ), + ), + ], + ), + (System, {"name": "test"}, []), + ], + ) + def test_resolved_parameters( + self, + monkeypatch: Any, + class_: type, + config: Dict, + expected_output: List[Tuple[Optional[str], Param]], + ) -> None: + """Test command building.""" + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + backend = class_(config) + + params = backend.resolved_parameters( + "build", ["choice_param_0=b", "choice_param_3=c"] + ) + assert params == expected_output + + @pytest.mark.parametrize( + ["param_name", "user_param", "expected_value"], + [ + ( + "test_name", + "test_name=1234", + "1234", + ), # optional parameter using '=' + ( + "test_name", + "test_name 1234", + "1234", + ), # optional parameter using ' ' + ("test_name", "test_name", None), # flag + (None, "test_name=1234", "1234"), # positional parameter + ], + ) + def test_resolved_user_parameters( + self, param_name: str, user_param: str, expected_value: str + ) -> None: + """Test different variants to provide user parameters.""" + # A dummy config providing one backend config + config = { + "name": "test_backend", + "commands": { + "test": ["user_param:test_param"], + }, + "user_params": { + "test": [UserParamConfig(name=param_name, alias="test_name")], + }, + } + backend = Backend(cast(BaseBackendConfig, config)) + params = backend.resolved_parameters( + command_name="test", user_params=[user_param] + ) + assert len(params) == 1 + value, param = params[0] + assert param_name == param.name + assert expected_value == value + + @pytest.mark.parametrize( + "input_param,expected", + [ + ("--param=1", ("--param", "1")), + ("--param 1", ("--param", "1")), + ("--flag", ("--flag", None)), + ], + ) + def test__parse_raw_parameter( + self, input_param: str, expected: Tuple[str, Optional[str]] + ) -> None: + """Test internal method of parsing a single raw parameter.""" + assert parse_raw_parameter(input_param) == expected + + +class TestParam: + """Test Param class.""" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param(name="test", description="desc", values=["values"]) + param2 = Param(name="test", description="desc", values=["values"]) + param3 = Param(name="test1", description="desc", values=["values"]) + param4 = object() + + assert param1 == param2 + assert param1 != param3 + assert param1 != param4 + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + assert param1.get_details() == { + "name": "test", + "values": ["values"], + "description": "desc", + } + + def test_invalid(self) -> None: + """Test invalid use cases for the Param class.""" + with pytest.raises( + ConfigurationException, + match="Either name, alias or both must be set to identify a parameter.", + ): + Param(name=None, description="desc", values=["values"]) + + +class TestCommand: + """Test Command class.""" + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + assert command1.get_details() == { + "command_strings": ["echo test"], + "user_params": [ + {"name": "test", "values": ["values"], "description": "desc"} + ], + } + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param("test", "desc", ["values"]) + param2 = Param("test1", "desc1", ["values1"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + command2 = Command(command_strings=["echo test"], params=[param1]) + command3 = Command(command_strings=["echo test"]) + command4 = Command(command_strings=["echo test"], params=[param2]) + command5 = object() + + assert command1 == command2 + assert command1 != command3 + assert command1 != command4 + assert command1 != command5 + + @pytest.mark.parametrize( + "params, expected_error", + [ + [[], does_not_raise()], + [[Param("param", "param description", [])], does_not_raise()], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", "param description", [], None), + ], + does_not_raise(), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias2"), + ], + does_not_raise(), + ], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises(ConfigurationException, match="Non unique aliases alias"), + ], + [ + [ + Param("alias", "param description", [], None, "alias1"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("param1", "param1 description", [], None, "alias1"), + ], + does_not_raise(), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("alias", "param1 description", [], None, "alias1"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias1"), + Param("param3", "param3 description", [], None, "alias2"), + Param("param4", "param4 description", [], None, "alias2"), + ], + pytest.raises( + ConfigurationException, match="Non unique aliases alias1, alias2" + ), + ], + ], + ) + def test_validate_params(self, params: List[Param], expected_error: Any) -> None: + """Test command validation function.""" + with expected_error: + Command([], params) diff --git a/tests/aiet/test_backend_controller.py b/tests/aiet/test_backend_controller.py new file mode 100644 index 0000000..8836ec5 --- /dev/null +++ b/tests/aiet/test_backend_controller.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for system controller.""" +import csv +import os +import time +from pathlib import Path +from typing import Any + +import psutil +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.controller import SystemController +from aiet.backend.controller import SystemControllerSingleInstance +from aiet.utils.proc import ShellCommand + + +def get_system_controller(**kwargs: Any) -> SystemController: + """Get service controller.""" + single_instance = kwargs.get("single_instance", False) + if single_instance: + pid_file_path = kwargs.get("pid_file_path") + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +def test_service_controller() -> None: + """Test service controller functionality.""" + service_controller = get_system_controller() + + assert service_controller.get_output() == ("", "") + with pytest.raises(ConfigurationException, match="Wrong working directory"): + service_controller.start(["sleep 100"], Path("unknown")) + + service_controller.start(["sleep 100"], Path.cwd()) + assert service_controller.is_running() + + service_controller.stop(True) + assert not service_controller.is_running() + assert service_controller.get_output() == ("", "") + + service_controller.stop() + + with pytest.raises( + ConfigurationException, match="System should have only one command to run" + ): + service_controller.start(["sleep 100", "sleep 101"], Path.cwd()) + + with pytest.raises(ConfigurationException, match="No startup command provided"): + service_controller.start([""], Path.cwd()) + + +def test_service_controller_bad_configuration() -> None: + """Test service controller functionality for bad configuration.""" + with pytest.raises(Exception, match="No pid file path presented"): + service_controller = get_system_controller( + single_instance=True, pid_file_path=None + ) + service_controller.start(["sleep 100"], Path.cwd()) + + +def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None: + """Test that controller writes process info correctly.""" + pid_file = Path(tmpdir) / "test.pid" + + service_controller = get_system_controller( + single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" + ) + + service_controller.start(["sleep 100"], Path.cwd()) + assert service_controller.is_running() + assert pid_file.is_file() + + with open(pid_file, "r", encoding="utf-8") as file: + csv_reader = csv.reader(file) + rows = list(csv_reader) + assert len(rows) == 1 + + name, *_ = rows[0] + assert name == "sleep" + + service_controller.stop() + assert pid_file.exists() + + +def test_service_controller_does_not_write_process_info_if_process_finishes( + tmpdir: Any, +) -> None: + """Test that controller does not write process info if process already finished.""" + pid_file = Path(tmpdir) / "test.pid" + service_controller = get_system_controller( + single_instance=True, pid_file_path=pid_file + ) + service_controller.is_running = lambda: False # type: ignore + service_controller.start(["echo hello"], Path.cwd()) + + assert not pid_file.exists() + + +def test_service_controller_searches_for_previous_instances_correctly( + tmpdir: Any, +) -> None: + """Test that controller searches for previous instances correctly.""" + pid_file = Path(tmpdir) / "test.pid" + command = ShellCommand().run("sleep", "100") + assert command.is_alive() + + pid = command.process.pid + process = psutil.Process(pid) + with open(pid_file, "w", encoding="utf-8") as file: + csv_writer = csv.writer(file) + csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid())) + csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid)) + csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777)) + + service_controller = get_system_controller( + single_instance=True, pid_file_path=pid_file + ) + service_controller.start(["sleep 100"], Path.cwd()) + # controller should stop this process as it is currently running and + # mentioned in pid file + assert not command.is_alive() + + service_controller.stop() + + +@pytest.mark.parametrize( + "executable", ["test_backend_run_script.sh", "test_backend_run"] +) +def test_service_controller_run_shell_script( + executable: str, test_resources_path: Path +) -> None: + """Test controller's ability to run shell scripts.""" + script_path = test_resources_path / "scripts" + + service_controller = get_system_controller() + + service_controller.start([executable], script_path) + + assert service_controller.is_running() + # give time for the command to produce output + time.sleep(2) + service_controller.stop(wait=True) + assert not service_controller.is_running() + stdout, stderr = service_controller.get_output() + assert stdout == "Hello from script\n" + assert stderr == "Oops!\n" + + +def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None: + """Test that nothing happened if controller is not started.""" + service_controller = get_system_controller( + single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" + ) + + assert not service_controller.is_running() + service_controller.stop() + assert not service_controller.is_running() diff --git a/tests/aiet/test_backend_execution.py b/tests/aiet/test_backend_execution.py new file mode 100644 index 0000000..8aa45f1 --- /dev/null +++ b/tests/aiet/test_backend_execution.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Test backend context module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Optional +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from sh import CommandNotFound + +from aiet.backend.application import Application +from aiet.backend.application import get_application +from aiet.backend.common import ConfigurationException +from aiet.backend.common import DataPaths +from aiet.backend.common import UserParamConfig +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import SystemConfig +from aiet.backend.execution import deploy_data +from aiet.backend.execution import execute_commands_locally +from aiet.backend.execution import ExecutionContext +from aiet.backend.execution import get_application_and_system +from aiet.backend.execution import get_application_by_name_and_system +from aiet.backend.execution import get_file_lock_path +from aiet.backend.execution import get_tool_by_system +from aiet.backend.execution import ParamResolver +from aiet.backend.execution import Reporter +from aiet.backend.execution import wait +from aiet.backend.output_parser import OutputParser +from aiet.backend.system import get_system +from aiet.backend.system import load_system +from aiet.backend.tool import get_tool +from aiet.utils.proc import CommandFailedException + + +def test_context_param_resolver(tmpdir: Any) -> None: + """Test parameter resolving.""" + system_config_location = Path(tmpdir) / "system" + system_config_location.mkdir() + + application_config_location = Path(tmpdir) / "application" + application_config_location.mkdir() + + ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + description="Test application", + config_location=application_config_location, + build_dir="build-{application.name}-{system.name}", + commands={ + "run": [ + "run_command1 {user_params:0}", + "run_command2 {user_params:1}", + ] + }, + variables={"var_1": "value for var_1"}, + user_params={ + "run": [ + UserParamConfig( + name="--param1", + description="Param 1", + default_value="123", + alias="param_1", + ), + UserParamConfig( + name="--param2", description="Param 2", default_value="456" + ), + UserParamConfig( + name="--param3", description="Param 3", alias="param_3" + ), + UserParamConfig( + name="--param4=", + description="Param 4", + default_value="456", + alias="param_4", + ), + UserParamConfig( + description="Param 5", + default_value="789", + alias="param_5", + ), + ] + }, + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={ + "build": ["build_command1 {user_params:0}"], + "run": ["run_command {application.commands.run:1}"], + }, + variables={"var_1": "value for var_1"}, + user_params={ + "build": [ + UserParamConfig( + name="--param1", description="Param 1", default_value="aaa" + ), + UserParamConfig(name="--param2", description="Param 2"), + ] + }, + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(ctx) + expected_values = { + "application.name": "test_application", + "application.description": "Test application", + "application.config_dir": str(application_config_location), + "application.build_dir": "{}/build-test_application-test_system".format( + application_config_location + ), + "application.commands.run:0": "run_command1 --param1 123", + "application.commands.run.params:0": "123", + "application.commands.run.params:param_1": "123", + "application.commands.run:1": "run_command2 --param2 789", + "application.commands.run.params:1": "789", + "application.variables:var_1": "value for var_1", + "system.name": "test_system", + "system.description": "Test system", + "system.config_dir": str(system_config_location), + "system.commands.build:0": "build_command1 --param1 bbb", + "system.commands.run:0": "run_command run_command2 --param2 789", + "system.commands.build.params:0": "bbb", + "system.variables:var_1": "value for var_1", + } + + for param, value in expected_values.items(): + assert param_resolver(param) == value + + assert ctx.build_dir() == Path( + "{}/build-test_application-test_system".format(application_config_location) + ) + + expected_errors = { + "application.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + "application.commands.clean:0": pytest.raises( + Exception, match="Command clean not found" + ), + "application.commands.run:2": pytest.raises( + Exception, match="Invalid index 2 for command run" + ), + "application.commands.run.params:5": pytest.raises( + Exception, match="Invalid parameter index 5 for command run" + ), + "application.commands.run.params:param_2": pytest.raises( + Exception, + match="No value for parameter with index or alias param_2 of command run", + ), + "UNKNOWN": pytest.raises( + Exception, match="Unable to resolve parameter UNKNOWN" + ), + "system.commands.build.params:1": pytest.raises( + Exception, + match="No value for parameter with index or alias 1 of command build", + ), + "system.commands.build:A": pytest.raises( + Exception, match="Bad command index A" + ), + "system.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + } + for param, error in expected_errors.items(): + with error: + param_resolver(param) + + resolved_params = ctx.app.resolved_parameters("run", []) + expected_user_params = { + "user_params:0": "--param1 123", + "user_params:param_1": "--param1 123", + "user_params:2": "--param3", + "user_params:param_3": "--param3", + "user_params:3": "--param4=456", + "user_params:param_4": "--param4=456", + "user_params:param_5": "789", + } + for param, expected_value in expected_user_params.items(): + assert param_resolver(param, "run", resolved_params) == expected_value + + with pytest.raises( + Exception, match="Invalid index 5 for user params of command run" + ): + param_resolver("user_params:5", "run", resolved_params) + + with pytest.raises( + Exception, match="No user parameter for command 'run' with alias 'param_2'." + ): + param_resolver("user_params:param_2", "run", resolved_params) + + with pytest.raises(Exception, match="Unable to resolve user params"): + param_resolver("user_params:0", "", resolved_params) + + bad_ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + config_location=application_config_location, + build_dir="build-{user_params:0}", + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build-{system.commands.run.params:123}", + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + param_resolver = ParamResolver(bad_ctx) + with pytest.raises(Exception, match="Unable to resolve user params"): + bad_ctx.build_dir() + + +# pylint: disable=too-many-arguments +@pytest.mark.parametrize( + "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path", + ( + ( + "test_application", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "$$test_application$!:", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "test_application", + True, + True, + Path("unknown"), + pytest.raises( + Exception, match="Invalid directory unknown for lock files provided" + ), + None, + ), + ( + "test_application", + False, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_system.lock"), + ), + ( + "test_application", + True, + False, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application.lock"), + ), + ( + "test_application", + False, + False, + Path("/tmp"), + pytest.raises(Exception, match="No filename for lock provided"), + None, + ), + ), +) +def test_get_file_lock_path( + application_name: str, + soft_lock: bool, + sys_lock: bool, + lock_dir: Path, + expected_error: Any, + expected_path: Path, +) -> None: + """Test get_file_lock_path function.""" + with expected_error: + ctx = ExecutionContext( + app=Application(ApplicationConfig(name=application_name, lock=soft_lock)), + app_params=[], + system=load_system( + SystemConfig( + name="test_system", + lock=sys_lock, + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=[], + custom_deploy_data=[], + ) + path = get_file_lock_path(ctx, lock_dir) + assert path == expected_path + + +def test_get_application_by_name_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_by_name_and_system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_application", + MagicMock(return_value=[MagicMock(), MagicMock()]), + ) + + with pytest.raises( + ValueError, + match="Error during getting application test_application for the " + "system test_system", + ): + get_application_by_name_and_system("test_application", "test_system") + + +def test_get_application_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_and_system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_system", MagicMock(return_value=None) + ) + + with pytest.raises(ValueError, match="System test_system is not found"): + get_application_and_system("test_application", "test_system") + + +def test_wait_function(monkeypatch: Any) -> None: + """Test wait function.""" + sleep_mock = MagicMock() + monkeypatch.setattr("time.sleep", sleep_mock) + wait(0.1) + sleep_mock.assert_called_once() + + +def test_deployment_execution_context() -> None: + """Test property 'is_deploy_needed' of the ExecutionContext.""" + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + ) + assert not ctx.is_deploy_needed + deploy_data(ctx) # should be a NOP + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + custom_deploy_data=[DataPaths(Path("README.md"), ".")], + ) + assert ctx.is_deploy_needed + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=None, + system_params=[], + ) + assert not ctx.is_deploy_needed + with pytest.raises(AssertionError): + deploy_data(ctx) + + ctx = ExecutionContext( + app=get_tool("tool_1")[0], + app_params=[], + system=None, + system_params=[], + ) + assert not ctx.is_deploy_needed + deploy_data(ctx) # should be a NOP + + +@pytest.mark.parametrize( + ["tool_name", "system_name", "exception"], + [ + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", None), + ("unknown tool", "Corstone-300: Cortex-M55+Ethos-U65", ConfigurationException), + ("vela", "unknown system", ConfigurationException), + ("vela", None, ConfigurationException), + ], +) +def test_get_tool_by_system( + tool_name: str, system_name: Optional[str], exception: Optional[Any] +) -> None: + """Test exceptions thrown by function get_tool_by_system().""" + + def test() -> None: + """Test call of get_tool_by_system().""" + tool = get_tool_by_system(tool_name, system_name) + assert tool is not None + + if exception is None: + test() + else: + with pytest.raises(exception): + test() + + +class TestExecuteCommandsLocally: + """Test execute_commands_locally() function.""" + + @pytest.mark.parametrize( + "first_command, exception, expected_output", + ( + ( + "echo 'hello'", + None, + "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n", + ), + ( + "non-existent-command", + CommandNotFound, + "Running: non-existent-command\n", + ), + ("false", CommandFailedException, "Running: false\n"), + ), + ids=( + "runs_multiple_commands", + "stops_executing_on_non_existent_command", + "stops_executing_when_command_exits_with_error_code", + ), + ) + def test_execution( + self, + first_command: str, + exception: Any, + expected_output: str, + test_resources_path: Path, + capsys: Any, + ) -> None: + """Test expected behaviour of the function.""" + commands = [first_command, "echo 'goodbye'"] + cwd = test_resources_path + if exception is None: + execute_commands_locally(commands, cwd) + else: + with pytest.raises(exception): + execute_commands_locally(commands, cwd) + + captured = capsys.readouterr() + assert captured.out == expected_output + + def test_stops_executing_on_exception( + self, monkeypatch: Any, test_resources_path: Path + ) -> None: + """Ensure commands following an error-exit-code command don't run.""" + # Mock execute_command() function + execute_command_mock = mock.MagicMock() + monkeypatch.setattr("aiet.utils.proc.execute_command", execute_command_mock) + + # Mock Command object and assign as return value to execute_command() + cmd_mock = mock.MagicMock() + execute_command_mock.return_value = cmd_mock + + # Mock the terminate_command (speed up test) + terminate_command_mock = mock.MagicMock() + monkeypatch.setattr("aiet.utils.proc.terminate_command", terminate_command_mock) + + # Mock a thrown Exception and assign to Command().exit_code + exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception.")) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Exception."): + execute_commands_locally( + ["command_1", "command_2"], cwd=test_resources_path + ) + + # Assert only "command_1" was executed + assert execute_command_mock.call_count == 1 + + +def test_reporter(tmpdir: Any) -> None: + """Test class 'Reporter'.""" + ctx = ExecutionContext( + app=get_application("application_4")[0], + app_params=["--app=TestApp"], + system=get_system("System 4"), + system_params=[], + ) + assert ctx.system is not None + + class MockParser(OutputParser): + """Mock implementation of an output parser.""" + + def __init__(self, metrics: Dict[str, Any]) -> None: + """Set up the MockParser.""" + super().__init__(name="test") + self.metrics = metrics + + def __call__(self, output: bytearray) -> Dict[str, Any]: + """Return mock metrics (ignoring the given output).""" + return self.metrics + + metrics = {"Metric": 123, "AnotherMetric": 456} + reporter = Reporter( + parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()], + ) + reporter.parse(bytearray()) + report = reporter.report(ctx) + assert report["system"]["name"] == ctx.system.name + assert report["system"]["params"] == {} + assert report["application"]["name"] == ctx.app.name + assert report["application"]["params"] == {"--app": "TestApp"} + assert report["test"]["metrics"] == metrics + report_file = Path(tmpdir) / "report.json" + reporter.save(report, report_file) + assert report_file.is_file() diff --git a/tests/aiet/test_backend_output_parser.py b/tests/aiet/test_backend_output_parser.py new file mode 100644 index 0000000..d659812 --- /dev/null +++ b/tests/aiet/test_backend_output_parser.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the output parsing.""" +import base64 +import json +from typing import Any +from typing import Dict + +import pytest + +from aiet.backend.output_parser import Base64OutputParser +from aiet.backend.output_parser import OutputParser +from aiet.backend.output_parser import RegexOutputParser + + +OUTPUT_MATCH_ALL = bytearray( + """ +String1: My awesome string! +String2: STRINGS_ARE_GREAT!!! +Int: 12 +Float: 3.14 +""", + encoding="utf-8", +) + +OUTPUT_NO_MATCH = bytearray( + """ +This contains no matches... +Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| +""", + encoding="utf-8", +) + +OUTPUT_PARTIAL_MATCH = bytearray( + "String1: My awesome string!", + encoding="utf-8", +) + +REGEX_CONFIG = { + "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, + "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, + "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, + "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, +} + +EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} + +EXPECTED_METRICS_ALL = { + "FirstString": "My awesome string!", + "SecondString": "STRINGS_ARE_GREAT", + "IntegerValue": 12, + "FloatValue": 3.14, +} + +EXPECTED_METRICS_PARTIAL = { + "FirstString": "My awesome string!", +} + + +class TestRegexOutputParser: + """Collect tests for the RegexOutputParser.""" + + @staticmethod + @pytest.mark.parametrize( + ["output", "config", "expected_metrics"], + [ + (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + ( + OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH, + REGEX_CONFIG, + EXPECTED_METRICS_ALL, + ), + (OUTPUT_NO_MATCH, REGEX_CONFIG, {}), + (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}), + (bytearray(), EMPTY_REGEX_CONFIG, {}), + (bytearray(), REGEX_CONFIG, {}), + ], + ) + def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None: + """ + Make sure the RegexOutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = RegexOutputParser(name="Test", regex_config=config) + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + res = parser(output) + assert res == expected_metrics + + @staticmethod + def test_unsupported_type() -> None: + """An unsupported type in the regex_config must raise an exception.""" + config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}} + with pytest.raises(TypeError): + RegexOutputParser(name="Test", regex_config=config) + + @staticmethod + @pytest.mark.parametrize( + "config", + ( + {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}}, + {"NoGroups": {"pattern": r"\W", "type": "str"}}, + ), + ) + def test_invalid_pattern(config: Dict) -> None: + """Exactly one capturing parenthesis is allowed in the regex pattern.""" + with pytest.raises(ValueError): + RegexOutputParser(name="Test", regex_config=config) + + +@pytest.mark.parametrize( + "expected_metrics", + [ + EXPECTED_METRICS_ALL, + EXPECTED_METRICS_PARTIAL, + ], +) +def test_base64_output_parser(expected_metrics: Dict) -> None: + """ + Make sure the Base64OutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = Base64OutputParser(name="Test") + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + + def create_base64_output(expected_metrics: Dict) -> bytearray: + json_str = json.dumps(expected_metrics, indent=4) + json_b64 = base64.b64encode(json_str.encode("utf-8")) + return ( + OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser + + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8") + + bytearray(json_b64) + + f"".encode("utf-8") + + OUTPUT_NO_MATCH # Just to add some difficulty... + ) + + output = create_base64_output(expected_metrics) + res = parser(output) + assert len(res) == 1 + assert isinstance(res, dict) + for val in res.values(): + assert val == expected_metrics + + output = parser.filter_out_parsed_content(output) + assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH) diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py new file mode 100644 index 0000000..2103238 --- /dev/null +++ b/tests/aiet/test_backend_protocol.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access +"""Tests for the protocol backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import paramiko +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.protocol import CustomSFTPClient +from aiet.backend.protocol import LocalProtocol +from aiet.backend.protocol import ProtocolFactory +from aiet.backend.protocol import SSHProtocol + + +class TestProtocolFactory: + """Test ProtocolFactory class.""" + + @pytest.mark.parametrize( + "config, expected_class, exception", + [ + ( + { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + }, + SSHProtocol, + does_not_raise(), + ), + ({"protocol": "local"}, LocalProtocol, does_not_raise()), + ( + {"protocol": "something"}, + None, + pytest.raises(Exception, match="Protocol not supported"), + ), + (None, None, pytest.raises(Exception, match="No protocol config provided")), + ], + ) + def test_get_protocol( + self, config: Any, expected_class: type, exception: Any + ) -> None: + """Test get_protocol method.""" + factory = ProtocolFactory() + with exception: + protocol = factory.get_protocol(config) + assert isinstance(protocol, expected_class) + + +class TestLocalProtocol: + """Test local protocol.""" + + def test_local_protocol_run_command(self) -> None: + """Test local protocol run command.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("/tmp")) + ret, stdout, stderr = protocol.run("pwd") + assert ret == 0 + assert stdout.decode("utf-8").strip() == "/tmp" + assert stderr.decode("utf-8") == "" + + def test_local_protocol_run_wrong_cwd(self) -> None: + """Execution should fail if wrong working directory provided.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("unknown_directory")) + with pytest.raises( + ConfigurationException, match="Wrong working directory unknown_directory" + ): + protocol.run("pwd") + + +class TestSSHProtocol: + """Test SSH protocol.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up protocol mocks.""" + self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient) + + self.mock_ssh_channel = ( + self.mock_ssh_client.get_transport.return_value.open_session.return_value + ) + self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel) + self.mock_ssh_channel.exit_status_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_exit_status.return_value = True + self.mock_ssh_channel.recv_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True] + + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.client.SSHClient", + MagicMock(return_value=self.mock_ssh_client), + ) + + self.mock_sftp_client = MagicMock(spec=CustomSFTPClient) + monkeypatch.setattr( + "aiet.backend.protocol.CustomSFTPClient.from_transport", + MagicMock(return_value=self.mock_sftp_client), + ) + + ssh_config = { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + } + self.protocol = SSHProtocol(ssh_config) + + def test_unable_create_ssh_client(self, monkeypatch: Any) -> None: + """Test that command should fail if unable to create ssh client instance.""" + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.client.SSHClient", + MagicMock(side_effect=OSError("Error!")), + ) + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command(self) -> None: + """Test that command run via ssh successfully.""" + self.protocol.run("command_example") + self.mock_ssh_channel.exec_command.assert_called_once() + + def test_ssh_protocol_run_command_connect_failed(self) -> None: + """Test that if connection is not possible then correct exception is raised.""" + self.mock_ssh_client.connect.side_effect = OSError("Unable to connect") + self.mock_ssh_client.close.side_effect = Exception("Error!") + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command_bad_transport(self) -> None: + """Test that command should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_deploy_command_file( + self, test_applications_path: Path + ) -> None: + """Test that files could be deployed over ssh.""" + file_for_deploy = test_applications_path / "readme.txt" + dest = "/tmp/dest" + + self.protocol.deploy(file_for_deploy, dest) + self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest) + + def test_ssh_protocol_deploy_command_unknown_file(self) -> None: + """Test that deploy will fail if file does not exist.""" + with pytest.raises(Exception, match="Deploy error: file type not supported"): + self.protocol.deploy(Path("unknown_file"), "/tmp/dest") + + def test_ssh_protocol_deploy_command_bad_transport(self) -> None: + """Test that deploy should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.deploy(Path("some_file"), "/tmp/dest") + + def test_ssh_protocol_deploy_command_directory( + self, test_resources_path: Path + ) -> None: + """Test that directory could be deployed over ssh.""" + directory_for_deploy = test_resources_path / "scripts" + dest = "/tmp/dest" + + self.protocol.deploy(directory_for_deploy, dest) + self.mock_sftp_client.put_dir.assert_called_once_with( + directory_for_deploy, dest + ) + + @pytest.mark.parametrize("establish_connection", (True, False)) + def test_ssh_protocol_close(self, establish_connection: bool) -> None: + """Test protocol close operation.""" + if establish_connection: + self.protocol.establish_connection() + self.protocol.close() + + call_count = 1 if establish_connection else 0 + assert self.mock_ssh_channel.exec_command.call_count == call_count + + def test_connection_details(self) -> None: + """Test getting connection details.""" + assert self.protocol.connection_details() == ("hostname", 22) + + +class TestCustomSFTPClient: + """Test CustomSFTPClient class.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up mocks for CustomSFTPClient instance.""" + self.mock_mkdir = MagicMock() + self.mock_put = MagicMock() + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.__init__", + MagicMock(return_value=None), + ) + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir + ) + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put + ) + + self.sftp_client = CustomSFTPClient(MagicMock()) + + def test_put_dir(self, test_systems_path: Path) -> None: + """Test deploying directory to remote host.""" + directory_for_deploy = test_systems_path / "system1" + + self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest") + assert self.mock_put.call_count == 3 + assert self.mock_mkdir.call_count == 3 + + def test_mkdir(self) -> None: + """Test creating directory on remote host.""" + self.mock_mkdir.side_effect = IOError("Cannot create directory") + + self.sftp_client._mkdir("new_directory", ignore_existing=True) + + with pytest.raises(IOError, match="Cannot create directory"): + self.sftp_client._mkdir("new_directory", ignore_existing=False) diff --git a/tests/aiet/test_backend_source.py b/tests/aiet/test_backend_source.py new file mode 100644 index 0000000..13b2c6d --- /dev/null +++ b/tests/aiet/test_backend_source.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the source backend module.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.source import create_destination_and_install +from aiet.backend.source import DirectorySource +from aiet.backend.source import get_source +from aiet.backend.source import TarArchiveSource + + +def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None: + """Test create_destination_and_install function.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + resources = Path(tmpdir) + create_destination_and_install(dir_source, resources) + assert (resources / "system1").is_dir() + + +@patch("aiet.backend.source.DirectorySource.create_destination", return_value=False) +def test_create_destination_and_install_if_dest_creation_not_required( + mock_ds_create_destination: Any, tmpdir: Any +) -> None: + """Test create_destination_and_install function.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception): + create_destination_and_install(dir_source, resources) + + mock_ds_create_destination.assert_called_once() + + +def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None: + """Test create_destination_and_install function if installation fails.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception, match="Directory .* does not exist"): + create_destination_and_install(dir_source, resources) + assert not (resources / "unknown").exists() + assert resources.exists() + + +def test_create_destination_and_install_if_name_is_empty() -> None: + """Test create_destination_and_install function fails if source name is empty.""" + source = MagicMock() + source.create_destination.return_value = True + source.name.return_value = None + + with pytest.raises(Exception, match="Unable to get source name"): + create_destination_and_install(source, Path("some_path")) + + source.install_into.assert_not_called() + + +@pytest.mark.parametrize( + "source_path, expected_class, expected_error", + [ + (Path("applications/application1/"), DirectorySource, does_not_raise()), + ( + Path("archives/applications/application1.tar.gz"), + TarArchiveSource, + does_not_raise(), + ), + ( + Path("doesnt/exist"), + None, + pytest.raises( + ConfigurationException, match="Unable to read .*doesnt/exist" + ), + ), + ], +) +def test_get_source( + source_path: Path, + expected_class: Any, + expected_error: Any, + test_resources_path: Path, +) -> None: + """Test get_source function.""" + with expected_error: + full_source_path = test_resources_path / source_path + source = get_source(full_source_path) + assert isinstance(source, expected_class) + + +class TestDirectorySource: + """Test DirectorySource class.""" + + @pytest.mark.parametrize( + "directory, name", + [ + (Path("/some/path/some_system"), "some_system"), + (Path("some_system"), "some_system"), + ], + ) + def test_name(self, directory: Path, name: str) -> None: + """Test getting source name.""" + assert DirectorySource(directory).name() == name + + def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None: + """Test install directory into destination.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + with pytest.raises(Exception, match="Wrong destination .*"): + dir_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + dir_source.install_into(tmpdir_path) + source_files = [f.name for f in system_directory.iterdir()] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None: + """Test install system from unknown directory.""" + with pytest.raises(Exception, match="Directory .* does not exist"): + DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir)) + + +class TestTarArchiveSource: + """Test TarArchiveSource class.""" + + @pytest.mark.parametrize( + "archive, name", + [ + (Path("some_archive.tgz"), "some_archive"), + (Path("some_archive.tar.gz"), "some_archive"), + (Path("some_archive"), "some_archive"), + ("archives/systems/system1.tar.gz", "system1"), + ("archives/systems/system1_dir.tar.gz", "system1"), + ], + ) + def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None: + """Test getting source name.""" + assert TarArchiveSource(test_resources_path / archive).name() == name + + def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None: + """Test install archive into destination.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + with pytest.raises(Exception, match="Wrong destination .*"): + tar_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + tar_source.install_into(tmpdir_path) + source_files = [ + "aiet-config.json.license", + "aiet-config.json", + "system_artifact", + ] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None: + """Test install unknown source archive.""" + with pytest.raises(Exception, match="File .* does not exist"): + TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir)) + + def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None: + """Test install unsupported file type.""" + plain_text_file = Path(tmpdir) / "test_file" + plain_text_file.write_text("Not a system config") + + with pytest.raises(Exception, match="Unsupported archive type .*"): + TarArchiveSource(plain_text_file).install_into(Path(tmpdir)) + + def test_lazy_property_init(self, test_resources_path: Path) -> None: + """Test that class properties initialized correctly.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + assert tar_source.name() == "system1" + assert tar_source.config() is not None + assert tar_source.create_destination() + + tar_source = TarArchiveSource(system_archive) + assert tar_source.config() is not None + assert tar_source.create_destination() + assert tar_source.name() == "system1" + + def test_create_destination_property(self, test_resources_path: Path) -> None: + """Test create_destination property filled correctly for different archives.""" + system_archive1 = test_resources_path / "archives/systems/system1.tar.gz" + system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz" + + assert TarArchiveSource(system_archive1).create_destination() + assert not TarArchiveSource(system_archive2).create_destination() diff --git a/tests/aiet/test_backend_system.py b/tests/aiet/test_backend_system.py new file mode 100644 index 0000000..a581547 --- /dev/null +++ b/tests/aiet/test_backend_system.py @@ -0,0 +1,536 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for system backend.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.common import Command +from aiet.backend.common import ConfigurationException +from aiet.backend.common import Param +from aiet.backend.common import UserParamConfig +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import ProtocolConfig +from aiet.backend.config import SSHConfig +from aiet.backend.config import SystemConfig +from aiet.backend.controller import SystemController +from aiet.backend.controller import SystemControllerSingleInstance +from aiet.backend.protocol import LocalProtocol +from aiet.backend.protocol import SSHProtocol +from aiet.backend.protocol import SupportsClose +from aiet.backend.protocol import SupportsDeploy +from aiet.backend.system import ControlledSystem +from aiet.backend.system import get_available_systems +from aiet.backend.system import get_controller +from aiet.backend.system import get_system +from aiet.backend.system import install_system +from aiet.backend.system import load_system +from aiet.backend.system import remove_system +from aiet.backend.system import StandaloneSystem +from aiet.backend.system import System + + +def dummy_resolver( + values: Optional[Dict[str, str]] = None +) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]: + """Return dummy parameter resolver implementation.""" + # pylint: disable=unused-argument + def resolver( + param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]] + ) -> str: + """Implement dummy parameter resolver.""" + return values.get(param, "") if values else "" + + return resolver + + +def test_get_available_systems() -> None: + """Test get_available_systems mocking get_resources.""" + available_systems = get_available_systems() + assert all(isinstance(s, System) for s in available_systems) + assert len(available_systems) == 3 + assert [str(s) for s in available_systems] == ["System 1", "System 2", "System 4"] + + +def test_get_system() -> None: + """Test get_system.""" + system1 = get_system("System 1") + assert isinstance(system1, ControlledSystem) + assert system1.connectable is True + assert system1.connection_details() == ("localhost", 8021) + assert system1.name == "System 1" + + system2 = get_system("System 2") + # check that comparison with object of another type returns false + assert system1 != 42 + assert system1 != system2 + + system = get_system("Unknown system") + assert system is None + + +@pytest.mark.parametrize( + "source, call_count, exception_type", + ( + ( + "archives/systems/system1.tar.gz", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "archives/systems/system3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ( + "systems/system1", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "systems/system3", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")), + ( + "various/systems/system_with_empty_config", + 0, + pytest.raises(Exception, match="No system definition found"), + ), + ("various/systems/system_with_valid_config", 1, does_not_raise()), + ), +) +def test_install_system( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + exception_type: Any, +) -> None: + """Test system installation from archive.""" + mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "aiet.backend.system.create_destination_and_install", + mock_create_destination_and_install, + ) + + with exception_type: + install_system(test_resources_path / source) + + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_system(monkeypatch: Any) -> None: + """Test system removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("aiet.backend.system.remove_backend", mock_remove_backend) + remove_system("some_system_dir") + mock_remove_backend.assert_called_once() + + +def test_system(monkeypatch: Any) -> None: + """Test the System class.""" + config = SystemConfig(name="System 1") + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + system = System(config) + assert str(system) == "System 1" + assert system.name == "System 1" + + +def test_system_with_empty_parameter_name() -> None: + """Test that configuration fails if parameter name is empty.""" + bad_config = SystemConfig( + name="System 1", + commands={"run": ["run"]}, + user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]}, + ) + with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."): + System(bad_config) + + +def test_system_standalone_run() -> None: + """Test run operation for standalone system.""" + system = get_system("System 4") + assert isinstance(system, StandaloneSystem) + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.connection_details() + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.establish_connection() + + assert system.connectable is False + + system.run("echo 'application run'") + + +@pytest.mark.parametrize( + "system_name, expected_value", [("System 1", True), ("System 4", False)] +) +def test_system_supports_deploy(system_name: str, expected_value: bool) -> None: + """Test system property supports_deploy.""" + system = get_system(system_name) + if system is None: + pytest.fail("Unable to get system {}".format(system_name)) + assert system.supports_deploy == expected_value + + +@pytest.mark.parametrize( + "mock_protocol", + [ + MagicMock(spec=SSHProtocol), + MagicMock( + spec=SSHProtocol, + **{"close.side_effect": ValueError("Unable to close protocol")} + ), + MagicMock(spec=LocalProtocol), + ], +) +def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None: + """Test system start, run commands and stop.""" + monkeypatch.setattr( + "aiet.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = get_system("System 1") + if system is None: + pytest.fail("Unable to get system") + assert isinstance(system, ControlledSystem) + + with pytest.raises(Exception, match="System has not been started"): + system.stop() + + assert not system.is_running() + assert system.get_output() == ("", "") + system.start(["sleep 10"], False) + assert system.is_running() + system.stop(wait=True) + assert not system.is_running() + assert system.get_output() == ("", "") + + if isinstance(mock_protocol, SupportsClose): + mock_protocol.close.assert_called_once() + + if isinstance(mock_protocol, SSHProtocol): + system.establish_connection() + + +def test_system_start_no_config_location() -> None: + """Test that system without config location could not start.""" + system = load_system( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", + hostname="localhost", + port="123", + ), + ) + ) + + assert isinstance(system, ControlledSystem) + with pytest.raises( + ConfigurationException, match="System test has wrong config location" + ): + system.start(["sleep 100"]) + + +@pytest.mark.parametrize( + "config, expected_class, expected_error", + [ + ( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", + hostname="localhost", + port="123", + ), + ), + ControlledSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", data_transfer=LocalProtocolConfig(protocol="local") + ), + StandaloneSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", + data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore + ), + None, + pytest.raises( + Exception, match="Unsupported execution type for protocol cool_protocol" + ), + ), + ], +) +def test_load_system( + config: SystemConfig, expected_class: type, expected_error: Any +) -> None: + """Test load_system function.""" + if not expected_class: + with expected_error: + load_system(config) + else: + system = load_system(config) + assert isinstance(system, expected_class) + + +def test_load_system_populate_shared_params() -> None: + """Test shared parameters population.""" + with pytest.raises(Exception, match="All shared parameters should have aliases"): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + ) + ] + }, + ) + ) + + with pytest.raises( + Exception, match="All parameters for command run should have aliases" + ): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + ) + ], + }, + ) + ) + system0 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["run_command"]}, + user_params={ + "shared": [], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system0.commands) == 1 + run_command1 = system0.commands["run"] + assert run_command1 == Command( + ["run_command"], + [ + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ) + ], + ) + + system1 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system1.commands) == 2 + build_command1 = system1.commands["build"] + assert build_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command1 = system1.commands["run"] + assert run_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + system2 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"build": ["build_command"]}, + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system2.commands) == 2 + build_command2 = system2.commands["build"] + assert build_command2 == Command( + ["build_command"], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command2 = system1.commands["run"] + assert run_command2 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + +@pytest.mark.parametrize( + "mock_protocol, expected_call_count", + [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)], +) +def test_system_deploy_data( + monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int +) -> None: + """Test deploy data functionality.""" + monkeypatch.setattr( + "aiet.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = ControlledSystem(SystemConfig(name="test")) + system.deploy(Path("some_file"), "some_dest") + + assert mock_protocol.deploy.call_count == expected_call_count + + +@pytest.mark.parametrize( + "single_instance, controller_class", + ((False, SystemController), (True, SystemControllerSingleInstance)), +) +def test_get_controller(single_instance: bool, controller_class: type) -> None: + """Test function get_controller.""" + controller = get_controller(single_instance) + assert isinstance(controller, controller_class) diff --git a/tests/aiet/test_backend_tool.py b/tests/aiet/test_backend_tool.py new file mode 100644 index 0000000..fd5960d --- /dev/null +++ b/tests/aiet/test_backend_tool.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the tool backend.""" +from collections import Counter + +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.config import ToolConfig +from aiet.backend.tool import get_available_tool_directory_names +from aiet.backend.tool import get_available_tools +from aiet.backend.tool import get_tool +from aiet.backend.tool import Tool + + +def test_get_available_tool_directory_names() -> None: + """Test get_available_tools mocking get_resources.""" + directory_names = get_available_tool_directory_names() + assert Counter(directory_names) == Counter(["tool1", "tool2", "vela"]) + + +def test_get_available_tools() -> None: + """Test get_available_tools mocking get_resources.""" + available_tools = get_available_tools() + expected_tool_names = sorted( + [ + "tool_1", + "tool_2", + "vela", + "vela", + "vela", + ] + ) + + assert all(isinstance(s, Tool) for s in available_tools) + assert all(s != 42 for s in available_tools) + assert any(s == available_tools[0] for s in available_tools) + assert len(available_tools) == len(expected_tool_names) + available_tool_names = sorted(str(s) for s in available_tools) + assert available_tool_names == expected_tool_names + + +def test_get_tool() -> None: + """Test get_tool mocking get_resoures.""" + tools = get_tool("tool_1") + assert len(tools) == 1 + tool = tools[0] + assert tool is not None + assert isinstance(tool, Tool) + assert tool.name == "tool_1" + + tools = get_tool("unknown tool") + assert not tools + + +def test_tool_creation() -> None: + """Test edge cases when creating a Tool instance.""" + with pytest.raises(ConfigurationException): + Tool(ToolConfig(name="test", commands={"test": []})) # no 'run' command diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py new file mode 100644 index 0000000..4eafe59 --- /dev/null +++ b/tests/aiet/test_check_model.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name,no-self-use +"""Module for testing check_model.py script.""" +from pathlib import Path +from typing import Any + +import pytest +from ethosu.vela.tflite.Model import Model +from ethosu.vela.tflite.OperatorCode import OperatorCode + +from aiet.cli.common import InvalidTFLiteFileError +from aiet.cli.common import ModelOptimisedException +from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu +from aiet.resources.tools.vela.check_model import check_model +from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators +from aiet.resources.tools.vela.check_model import get_model_from_file +from aiet.resources.tools.vela.check_model import get_operators_from_model +from aiet.resources.tools.vela.check_model import is_vela_optimised + + +@pytest.fixture(scope="session") +def optimised_tflite_model( + optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(optimised_input_model_file) + + +@pytest.fixture(scope="session") +def non_optimised_tflite_model( + non_optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(non_optimised_input_model_file) + + +class TestIsVelaOptimised: + """Test class for is_vela_optimised() function.""" + + def test_return_true_when_input_is_optimised( + self, + optimised_tflite_model: Model, + ) -> None: + """Verify True returned when input is optimised model.""" + output = is_vela_optimised(optimised_tflite_model) + + assert output is True + + def test_return_false_when_input_is_not_optimised( + self, + non_optimised_tflite_model: Model, + ) -> None: + """Verify False returned when input is non-optimised model.""" + output = is_vela_optimised(non_optimised_tflite_model) + + assert output is False + + +def test_get_operator_list_returns_correct_instances( + optimised_tflite_model: Model, +) -> None: + """Verify list of OperatorCode instances returned by get_operator_list().""" + operator_list = get_operators_from_model(optimised_tflite_model) + + assert all(isinstance(operator, OperatorCode) for operator in operator_list) + + +class TestGetCustomCodesFromOperators: + """Test the get_custom_codes_from_operators() function.""" + + def test_returns_empty_list_when_input_operators_have_no_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify function returns empty list when operators have no custom codes.""" + # Mock OperatorCode.CustomCode() function to return None + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [] + + def test_returns_custom_codes_when_input_operators_have_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify list of bytes objects returned representing the CustomCodes.""" + # Mock OperatorCode.CustomCode() function to return a byte string + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", + lambda _: b"custom-code", + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"] + + +@pytest.mark.parametrize( + "custom_codes, expected_output", + [ + ([b"ethos-u", b"something else"], True), + ([b"custom-code-1", b"custom-code-2"], False), + ], +) +def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None: + """Verify function detects 'ethos-u' bytes in the input list.""" + output = check_custom_codes_for_ethosu(custom_codes) + assert output is expected_output + + +class TestGetModelFromFile: + """Test the get_model_from_file() function.""" + + def test_error_raised_when_input_is_invalid_model_file( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify error thrown when an invalid model file is given.""" + with pytest.raises(InvalidTFLiteFileError): + get_model_from_file(invalid_input_model_file) + + def test_model_instance_returned_when_input_is_valid_model_file( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify file is read successfully and returns model instance.""" + tflite_model = get_model_from_file(optimised_input_model_file) + + assert isinstance(tflite_model, Model) + + +class TestCheckModel: + """Test the check_model() function.""" + + def test_check_model_with_non_optimised_input( + self, + non_optimised_input_model_file: Path, + ) -> None: + """Verify no error occurs for a valid input file.""" + check_model(non_optimised_input_model_file) + + def test_check_model_with_optimised_input( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify that the right exception is raised with already optimised input.""" + with pytest.raises(ModelOptimisedException): + check_model(optimised_input_model_file) + + def test_check_model_with_invalid_input( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify that an exception is raised with invalid input.""" + with pytest.raises(Exception): + check_model(invalid_input_model_file) diff --git a/tests/aiet/test_cli.py b/tests/aiet/test_cli.py new file mode 100644 index 0000000..e8589fa --- /dev/null +++ b/tests/aiet/test_cli.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for testing CLI top command.""" +from typing import Any +from unittest.mock import ANY +from unittest.mock import MagicMock + +from click.testing import CliRunner + +from aiet.cli import cli + + +def test_cli(cli_runner: CliRunner) -> None: + """Test CLI top level command.""" + result = cli_runner.invoke(cli) + assert result.exit_code == 0 + assert "system" in cli.commands + assert "application" in cli.commands + + +def test_cli_version(cli_runner: CliRunner) -> None: + """Test version option.""" + result = cli_runner.invoke(cli, ["--version"]) + assert result.exit_code == 0 + assert "version" in result.output + + +def test_cli_verbose(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test verbose option.""" + with monkeypatch.context() as mock_context: + mock = MagicMock() + # params[1] is the verbose option and we need to replace the + # callback with a mock object + mock_context.setattr(cli.params[1], "callback", mock) + cli_runner.invoke(cli, ["-vvvv"]) + # 4 is the number -v called earlier + mock.assert_called_once_with(ANY, ANY, 4) diff --git a/tests/aiet/test_cli_application.py b/tests/aiet/test_cli_application.py new file mode 100644 index 0000000..f1ccc44 --- /dev/null +++ b/tests/aiet/test_cli_application.py @@ -0,0 +1,1153 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals,redefined-outer-name,too-many-lines +"""Module for testing CLI application subcommand.""" +import base64 +import json +import re +import time +from contextlib import contextmanager +from contextlib import ExitStack +from pathlib import Path +from typing import Any +from typing import Generator +from typing import IO +from typing import List +from typing import Optional +from typing import TypedDict +from unittest.mock import MagicMock + +import click +import pytest +from click.testing import CliRunner +from filelock import FileLock + +from aiet.backend.application import Application +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import SSHConfig +from aiet.backend.config import SystemConfig +from aiet.backend.config import UserParamConfig +from aiet.backend.output_parser import Base64OutputParser +from aiet.backend.protocol import SSHProtocol +from aiet.backend.system import load_system +from aiet.cli.application import application_cmd +from aiet.cli.application import details_cmd +from aiet.cli.application import execute_cmd +from aiet.cli.application import install_cmd +from aiet.cli.application import list_cmd +from aiet.cli.application import parse_payload_run_config +from aiet.cli.application import remove_cmd +from aiet.cli.application import run_cmd +from aiet.cli.common import MiddlewareExitCode + + +def test_application_cmd() -> None: + """Test application commands.""" + commands = ["list", "details", "install", "remove", "execute", "run"] + assert all(command in application_cmd.commands for command in commands) + + +@pytest.mark.parametrize("format_", ["json", "cli"]) +def test_application_cmd_context(cli_runner: CliRunner, format_: str) -> None: + """Test setting command context parameters.""" + result = cli_runner.invoke(application_cmd, ["--format", format_]) + # command should fail if no subcommand provided + assert result.exit_code == 2 + + result = cli_runner.invoke(application_cmd, ["--format", format_, "list"]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "format_, system_name, expected_output", + [ + ( + "json", + None, + '{"type": "application", "available": ["application_1", "application_2"]}\n', + ), + ( + "json", + "system_1", + '{"type": "application", "available": ["application_1"]}\n', + ), + ("cli", None, "Available applications:\n\napplication_1\napplication_2\n"), + ("cli", "system_1", "Available applications:\n\napplication_1\n"), + ], +) +def test_list_cmd( + cli_runner: CliRunner, + monkeypatch: Any, + format_: str, + system_name: str, + expected_output: str, +) -> None: + """Test available applications commands.""" + # Mock some applications + mock_application_1 = MagicMock(spec=Application) + mock_application_1.name = "application_1" + mock_application_1.can_run_on.return_value = system_name == "system_1" + mock_application_2 = MagicMock(spec=Application) + mock_application_2.name = "application_2" + mock_application_2.can_run_on.return_value = system_name == "system_2" + + # Monkey patch the call get_available_applications + mock_available_applications = MagicMock() + mock_available_applications.return_value = [mock_application_1, mock_application_2] + + monkeypatch.setattr( + "aiet.backend.application.get_available_applications", + mock_available_applications, + ) + + obj = {"format": format_} + args = [] + if system_name: + list_cmd.params[0].type = click.Choice([system_name]) + args = ["--system", system_name] + result = cli_runner.invoke(list_cmd, obj=obj, args=args) + assert result.output == expected_output + + +def get_test_application() -> Application: + """Return test system details.""" + config = ApplicationConfig( + name="application", + description="test", + build_dir="", + supported_systems=[], + deploy_data=[], + user_params={}, + commands={ + "clean": ["clean"], + "build": ["build"], + "run": ["run"], + "post_run": ["post_run"], + }, + ) + + return Application(config) + + +def get_details_cmd_json_output() -> str: + """Get JSON output for details command.""" + json_output = """ +[ + { + "type": "application", + "name": "application", + "description": "test", + "supported_systems": [], + "commands": { + "clean": { + "command_strings": [ + "clean" + ], + "user_params": [] + }, + "build": { + "command_strings": [ + "build" + ], + "user_params": [] + }, + "run": { + "command_strings": [ + "run" + ], + "user_params": [] + }, + "post_run": { + "command_strings": [ + "post_run" + ], + "user_params": [] + } + } + } +]""" + return json.dumps(json.loads(json_output)) + "\n" + + +def get_details_cmd_console_output() -> str: + """Get console output for details command.""" + return ( + 'Application "application" details' + + "\nDescription: test" + + "\n\nSupported systems: " + + "\n\nclean commands:" + + "\nCommands: ['clean']" + + "\n\nbuild commands:" + + "\nCommands: ['build']" + + "\n\nrun commands:" + + "\nCommands: ['run']" + + "\n\npost_run commands:" + + "\nCommands: ['post_run']" + + "\n" + ) + + +@pytest.mark.parametrize( + "application_name,format_, expected_output", + [ + ("application", "json", get_details_cmd_json_output()), + ("application", "cli", get_details_cmd_console_output()), + ], +) +def test_details_cmd( + cli_runner: CliRunner, + monkeypatch: Any, + application_name: str, + format_: str, + expected_output: str, +) -> None: + """Test application details command.""" + monkeypatch.setattr( + "aiet.cli.application.get_application", + MagicMock(return_value=[get_test_application()]), + ) + + details_cmd.params[0].type = click.Choice(["application"]) + result = cli_runner.invoke( + details_cmd, obj={"format": format_}, args=["--name", application_name] + ) + assert result.exception is None + assert result.output == expected_output + + +def test_details_cmd_wrong_system(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test details command fails if application is not supported by the system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_application", MagicMock(return_value=[]) + ) + + details_cmd.params[0].type = click.Choice(["application"]) + details_cmd.params[1].type = click.Choice(["system"]) + result = cli_runner.invoke( + details_cmd, args=["--name", "application", "--system", "system"] + ) + assert result.exit_code == 2 + assert ( + "Application 'application' doesn't support the system 'system'" in result.stdout + ) + + +def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test install application command.""" + mock_install_application = MagicMock() + monkeypatch.setattr( + "aiet.cli.application.install_application", mock_install_application + ) + + args = ["--source", "test"] + cli_runner.invoke(install_cmd, args=args) + mock_install_application.assert_called_once_with(Path("test")) + + +def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test remove application command.""" + mock_remove_application = MagicMock() + monkeypatch.setattr( + "aiet.cli.application.remove_application", mock_remove_application + ) + remove_cmd.params[0].type = click.Choice(["test"]) + + args = ["--directory_name", "test"] + cli_runner.invoke(remove_cmd, args=args) + mock_remove_application.assert_called_once_with("test") + + +class ExecutionCase(TypedDict, total=False): + """Execution case.""" + + args: List[str] + lock_path: str + can_establish_connection: bool + establish_connection_delay: int + app_exit_code: int + exit_code: int + output: str + + +@pytest.mark.parametrize( + "application_config, system_config, executions", + [ + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + config_location=Path("wrong_location"), + commands={"build": ["echo build {application.name}"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + config_location=Path("wrong_location"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: Application test_application has wrong config location\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + deploy_data=[("sample_file", "/tmp/sample_file")], + commands={"build": ["echo build {application.name}"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: System test_system does not support data deploy\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={"build": ["echo build {application.name}"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: No build directory defined for the app test_application\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["new_system"], + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=1, + output="Error: Application 'test_application' doesn't support the system 'test_system'\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["false"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.BACKEND_ERROR, + output="""Running: false +Error: Execution failed. Please check output for the details.\n""", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + lock=True, + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + lock=True, + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param default +build test_application with param default\n""", + ), + ExecutionCase( + args=["-c", "build"], + lock_path="/tmp/middleware_test_application_test_system.lock", + exit_code=MiddlewareExitCode.CONCURRENT_ERROR, + output="Error: Another instance of the system is running\n", + ), + ExecutionCase( + args=["-c", "build", "--param=param=val3"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param val3 +build test_application with param val3\n""", + ), + ExecutionCase( + args=["-c", "build", "--param=param=newval"], + exit_code=1, + output="Error: Application parameter 'param=newval' not valid for command 'build'\n", + ), + ExecutionCase( + args=["-c", "some_command"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: Unsupported command some_command\n", + ), + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Running: echo run test_application on test_system +run test_application on test_system\n""", + ), + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + deploy_data=[("sample_file", "/tmp/sample_file")], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + lock=True, + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["sleep 100"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + lock_path="/tmp/middleware_test_system.lock", + exit_code=MiddlewareExitCode.CONCURRENT_ERROR, + output="Error: Another instance of the system is running\n", + ), + ExecutionCase( + args=[ + "-c", + "run", + "--deploy={application.config_location}/sample_file:/tmp/sample_file", + ], + exit_code=0, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + app_exit_code=1, + exit_code=0, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Application exited with exit code 1 +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.CONNECTION_ERROR, + can_establish_connection=False, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .......................................................................................... +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully. +Error: Couldn't connect to 'localhost:8022'.\n""", + ), + ExecutionCase( + args=["-c", "run", "--deploy=bad_format"], + exit_code=1, + output="Error: Invalid deploy parameter 'bad_format' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=:"], + exit_code=1, + output="Error: Invalid deploy parameter ':' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy= : "], + exit_code=1, + output="Error: Invalid deploy parameter ' : ' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=some_src_file:"], + exit_code=1, + output="Error: Invalid deploy parameter 'some_src_file:' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=:some_dst_file"], + exit_code=1, + output="Error: Invalid deploy parameter ':some_dst_file' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=unknown_file:/tmp/dest"], + exit_code=1, + output="Error: Path unknown_file does not exist\n", + ), + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["echo Unable to start system"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=4, + can_establish_connection=False, + establish_connection_delay=1, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . + +---------- test_system execution failed ---------- +Unable to start system + + + +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully. +Error: Execution failed. Please check output for the details.\n""", + ) + ], + ], + ], +) +def test_application_command_execution( + application_config: ApplicationConfig, + system_config: SystemConfig, + executions: List[ExecutionCase], + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, +) -> None: + """Test application command execution.""" + + @contextmanager + def lock_execution(lock_path: str) -> Generator[None, None, None]: + lock = FileLock(lock_path) + lock.acquire(timeout=1) + + try: + yield + finally: + lock.release() + + def replace_vars(str_val: str) -> str: + """Replace variables.""" + application_config_location = str( + application_config["config_location"].absolute() + ) + + return str_val.replace( + "{application.config_location}", application_config_location + ) + + for execution in executions: + init_execution_test( + monkeypatch, + tmpdir, + application_config, + system_config, + can_establish_connection=execution.get("can_establish_connection", True), + establish_conection_delay=execution.get("establish_connection_delay", 0), + remote_app_exit_code=execution.get("app_exit_code", 0), + ) + + lock_path = execution.get("lock_path") + + with ExitStack() as stack: + if lock_path: + stack.enter_context(lock_execution(lock_path)) + + args = [replace_vars(arg) for arg in execution["args"]] + + result = cli_runner.invoke( + execute_cmd, + args=["-n", application_config["name"], "-s", system_config["name"]] + + args, + ) + output = replace_vars(execution["output"]) + assert result.exit_code == execution["exit_code"] + assert result.stdout == output + + +@pytest.fixture(params=[False, True], ids=["run-cli", "run-json"]) +def payload_path_or_none(request: Any, tmp_path_factory: Any) -> Optional[Path]: + """Drives tests for run command so that it executes them both to use a json file, and to use CLI.""" + if request.param: + ret: Path = tmp_path_factory.getbasetemp() / "system_config_payload_file.json" + return ret + return None + + +def write_system_payload_config( + payload_file: IO[str], + application_config: ApplicationConfig, + system_config: SystemConfig, +) -> None: + """Write a json payload file for the given test configuration.""" + payload_dict = { + "id": system_config["name"], + "arguments": { + "application": application_config["name"], + }, + } + json.dump(payload_dict, payload_file) + + +@pytest.mark.parametrize( + "application_config, system_config, executions", + [ + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=[], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param default +build test_application with param default +Generating commands to execute +Running: echo run test_application on test_system +run test_application on test_system\n""", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["sleep 100"]}, + ), + [ + ExecutionCase( + args=[], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ) + ], + ], + ], +) +def test_application_run( + application_config: ApplicationConfig, + system_config: SystemConfig, + executions: List[ExecutionCase], + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, + payload_path_or_none: Path, +) -> None: + """Test application command execution.""" + for execution in executions: + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + if payload_path_or_none: + with open(payload_path_or_none, "w", encoding="utf-8") as payload_file: + write_system_payload_config( + payload_file, application_config, system_config + ) + + result = cli_runner.invoke( + run_cmd, + args=["--config", str(payload_path_or_none)], + ) + else: + result = cli_runner.invoke( + run_cmd, + args=["-n", application_config["name"], "-s", system_config["name"]] + + execution["args"], + ) + + assert result.stdout == execution["output"] + assert result.exit_code == execution["exit_code"] + + +@pytest.mark.parametrize( + "cmdline,error_pattern", + [ + [ + "--config {payload} -s test_system", + "when --config is set, the following parameters should not be provided", + ], + [ + "--config {payload} -n test_application", + "when --config is set, the following parameters should not be provided", + ], + [ + "--config {payload} -p mypar:3", + "when --config is set, the following parameters should not be provided", + ], + [ + "-p mypar:3", + "when --config is not set, the following parameters are required", + ], + ["-s test_system", "when --config is not set, --name is required"], + ["-n test_application", "when --config is not set, --system is required"], + ], +) +def test_application_run_invalid_param_combinations( + cmdline: str, + error_pattern: str, + cli_runner: CliRunner, + monkeypatch: Any, + tmp_path: Any, + tmpdir: Any, +) -> None: + """Test that invalid combinations arguments result in error as expected.""" + application_config = ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["echo build {application.name} with {user_params:0}"]}, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ) + system_config = SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ) + + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + payload_file = tmp_path / "payload.json" + payload_file.write_text("dummy") + result = cli_runner.invoke( + run_cmd, + args=cmdline.format(payload=payload_file).split(), + ) + found = re.search(error_pattern, result.stdout) + assert found, f"Cannot find pattern: [{error_pattern}] in \n[\n{result.stdout}\n]" + + +@pytest.mark.parametrize( + "payload,expected", + [ + pytest.param( + {"arguments": {}}, + None, + marks=pytest.mark.xfail(reason="no system 'id''", strict=True), + ), + pytest.param( + {"id": "testsystem"}, + None, + marks=pytest.mark.xfail(reason="no arguments object", strict=True), + ), + ( + {"id": "testsystem", "arguments": {"application": "testapp"}}, + ("testsystem", "testapp", [], [], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "par1": "val1"}, + }, + ("testsystem", "testapp", ["par1=val1"], [], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "application/par1": "val1"}, + }, + ("testsystem", "testapp", ["par1=val1"], [], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "system/par1": "val1"}, + }, + ("testsystem", "testapp", [], ["par1=val1"], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "deploy/par1": "val1"}, + }, + ("testsystem", "testapp", [], [], ["par1"], None), + ), + ( + { + "id": "testsystem", + "arguments": { + "application": "testapp", + "appar1": "val1", + "application/appar2": "val2", + "system/syspar1": "val3", + "deploy/depploypar1": "val4", + "application/appar3": "val5", + "system/syspar2": "val6", + "deploy/depploypar2": "val7", + }, + }, + ( + "testsystem", + "testapp", + ["appar1=val1", "appar2=val2", "appar3=val5"], + ["syspar1=val3", "syspar2=val6"], + ["depploypar1", "depploypar2"], + None, + ), + ), + ], +) +def test_parse_payload_run_config(payload: dict, expected: tuple) -> None: + """Test parsing of the JSON payload for the run_config command.""" + assert parse_payload_run_config(payload) == expected + + +def test_application_run_report( + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, +) -> None: + """Test flag '--report' of command 'application run'.""" + app_metrics = {"app_metric": 3.14} + app_metrics_b64 = base64.b64encode(json.dumps(app_metrics).encode("utf-8")) + application_config = ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["echo build {application.name} with {user_params:0}"]}, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ), + UserParamConfig( + name="p2", + description="another parameter, not overridden", + default_value="the-right-choice", + values=["the-right-choice", "the-bad-choice"], + ), + ] + }, + ) + system_config = SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={ + "run": [ + "echo run {application.name} on {system.name}", + f"echo build <{Base64OutputParser.TAG_NAME}>{app_metrics_b64.decode('utf-8')}", + ] + }, + reporting={ + "regex": { + "app_name": { + "pattern": r"run (.\S*) ", + "type": "str", + }, + "sys_name": { + "pattern": r"on (.\S*)", + "type": "str", + }, + } + }, + ) + report_file = Path(tmpdir) / "test_report.json" + param_val = "param=val1" + exit_code = MiddlewareExitCode.SUCCESS + + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + result = cli_runner.invoke( + run_cmd, + args=[ + "-n", + application_config["name"], + "-s", + system_config["name"], + "--report", + str(report_file), + "--param", + param_val, + ], + ) + assert result.exit_code == exit_code + assert report_file.is_file() + with open(report_file, "r", encoding="utf-8") as file: + report = json.load(file) + + assert report == { + "application": { + "metrics": {"0": {"app_metric": 3.14}}, + "name": "test_application", + "params": {"param": "val1", "p2": "the-right-choice"}, + }, + "system": { + "metrics": {"app_name": "test_application", "sys_name": "test_system"}, + "name": "test_system", + "params": {}, + }, + } + + +def init_execution_test( + monkeypatch: Any, + tmpdir: Any, + application_config: ApplicationConfig, + system_config: SystemConfig, + can_establish_connection: bool = True, + establish_conection_delay: float = 0, + remote_app_exit_code: int = 0, +) -> None: + """Init execution test.""" + application_name = application_config["name"] + system_name = system_config["name"] + + execute_cmd.params[0].type = click.Choice([application_name]) + execute_cmd.params[1].type = click.Choice([system_name]) + execute_cmd.params[2].type = click.Choice(["build", "run", "some_command"]) + + run_cmd.params[0].type = click.Choice([application_name]) + run_cmd.params[1].type = click.Choice([system_name]) + + if "config_location" not in application_config: + application_path = Path(tmpdir) / "application" + application_path.mkdir() + application_config["config_location"] = application_path + + # this file could be used as deploy parameter value or + # as deploy parameter in application configuration + sample_file = application_path / "sample_file" + sample_file.touch() + monkeypatch.setattr( + "aiet.backend.application.get_available_applications", + MagicMock(return_value=[Application(application_config)]), + ) + + ssh_protocol_mock = MagicMock(spec=SSHProtocol) + + def mock_establish_connection() -> bool: + """Mock establish connection function.""" + # give some time for the system to start + time.sleep(establish_conection_delay) + return can_establish_connection + + ssh_protocol_mock.establish_connection.side_effect = mock_establish_connection + ssh_protocol_mock.connection_details.return_value = ("localhost", 8022) + ssh_protocol_mock.run.return_value = ( + remote_app_exit_code, + bytearray(), + bytearray(), + ) + monkeypatch.setattr( + "aiet.backend.protocol.SSHProtocol", MagicMock(return_value=ssh_protocol_mock) + ) + + if "config_location" not in system_config: + system_path = Path(tmpdir) / "system" + system_path.mkdir() + system_config["config_location"] = system_path + monkeypatch.setattr( + "aiet.backend.system.get_available_systems", + MagicMock(return_value=[load_system(system_config)]), + ) + + monkeypatch.setattr("aiet.backend.execution.wait", MagicMock()) diff --git a/tests/aiet/test_cli_common.py b/tests/aiet/test_cli_common.py new file mode 100644 index 0000000..d018e44 --- /dev/null +++ b/tests/aiet/test_cli_common.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test for cli common module.""" +from typing import Any + +import pytest + +from aiet.cli.common import print_command_details +from aiet.cli.common import raise_exception_at_signal + + +def test_print_command_details(capsys: Any) -> None: + """Test print_command_details function.""" + command = { + "command_strings": ["echo test"], + "user_params": [ + {"name": "param_name", "description": "param_description"}, + { + "name": "param_name2", + "description": "param_description2", + "alias": "alias2", + }, + ], + } + print_command_details(command) + captured = capsys.readouterr() + assert "echo test" in captured.out + assert "param_name" in captured.out + assert "alias2" in captured.out + + +def test_raise_exception_at_signal() -> None: + """Test raise_exception_at_signal graceful shutdown.""" + with pytest.raises(Exception) as err: + raise_exception_at_signal(1, "") + + assert str(err.value) == "Middleware shutdown requested" diff --git a/tests/aiet/test_cli_system.py b/tests/aiet/test_cli_system.py new file mode 100644 index 0000000..fd39f31 --- /dev/null +++ b/tests/aiet/test_cli_system.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for testing CLI system subcommand.""" +import json +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union +from unittest.mock import MagicMock + +import click +import pytest +from click.testing import CliRunner + +from aiet.backend.config import SystemConfig +from aiet.backend.system import load_system +from aiet.backend.system import System +from aiet.cli.system import details_cmd +from aiet.cli.system import install_cmd +from aiet.cli.system import list_cmd +from aiet.cli.system import remove_cmd +from aiet.cli.system import system_cmd + + +def test_system_cmd() -> None: + """Test system commands.""" + commands = ["list", "details", "install", "remove"] + assert all(command in system_cmd.commands for command in commands) + + +@pytest.mark.parametrize("format_", ["json", "cli"]) +def test_system_cmd_context(cli_runner: CliRunner, format_: str) -> None: + """Test setting command context parameters.""" + result = cli_runner.invoke(system_cmd, ["--format", format_]) + # command should fail if no subcommand provided + assert result.exit_code == 2 + + result = cli_runner.invoke(system_cmd, ["--format", format_, "list"]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "format_,expected_output", + [ + ("json", '{"type": "system", "available": ["system1", "system2"]}\n'), + ("cli", "Available systems:\n\nsystem1\nsystem2\n"), + ], +) +def test_list_cmd_with_format( + cli_runner: CliRunner, monkeypatch: Any, format_: str, expected_output: str +) -> None: + """Test available systems command with different formats output.""" + # Mock some systems + mock_system1 = MagicMock() + mock_system1.name = "system1" + mock_system2 = MagicMock() + mock_system2.name = "system2" + + # Monkey patch the call get_available_systems + mock_available_systems = MagicMock() + mock_available_systems.return_value = [mock_system1, mock_system2] + monkeypatch.setattr("aiet.cli.system.get_available_systems", mock_available_systems) + + obj = {"format": format_} + result = cli_runner.invoke(list_cmd, obj=obj) + assert result.output == expected_output + + +def get_test_system( + annotations: Optional[Dict[str, Union[str, List[str]]]] = None +) -> System: + """Return test system details.""" + config = SystemConfig( + name="system", + description="test", + data_transfer={ + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8022", + }, + commands={ + "clean": ["clean"], + "build": ["build"], + "run": ["run"], + "post_run": ["post_run"], + }, + annotations=annotations or {}, + ) + + return load_system(config) + + +def get_details_cmd_json_output( + annotations: Optional[Dict[str, Union[str, List[str]]]] = None +) -> str: + """Test JSON output for details command.""" + ann_str = "" + if annotations is not None: + ann_str = '"annotations":{},'.format(json.dumps(annotations)) + + json_output = ( + """ +{ + "type": "system", + "name": "system", + "description": "test", + "data_transfer_protocol": "ssh", + "commands": { + "clean": + { + "command_strings": ["clean"], + "user_params": [] + }, + "build": + { + "command_strings": ["build"], + "user_params": [] + }, + "run": + { + "command_strings": ["run"], + "user_params": [] + }, + "post_run": + { + "command_strings": ["post_run"], + "user_params": [] + } + }, +""" + + ann_str + + """ + "available_application" : [] + } +""" + ) + return json.dumps(json.loads(json_output)) + "\n" + + +def get_details_cmd_console_output( + annotations: Optional[Dict[str, Union[str, List[str]]]] = None +) -> str: + """Test console output for details command.""" + ann_str = "" + if annotations: + val_str = "".join( + "\n\t{}: {}".format(ann_name, ann_value) + for ann_name, ann_value in annotations.items() + ) + ann_str = "\nAnnotations:{}".format(val_str) + return ( + 'System "system" details' + + "\nDescription: test" + + "\nData Transfer Protocol: ssh" + + "\nAvailable Applications: " + + ann_str + + "\n\nclean commands:" + + "\nCommands: ['clean']" + + "\n\nbuild commands:" + + "\nCommands: ['build']" + + "\n\nrun commands:" + + "\nCommands: ['run']" + + "\n\npost_run commands:" + + "\nCommands: ['post_run']" + + "\n" + ) + + +@pytest.mark.parametrize( + "format_,system,expected_output", + [ + ( + "json", + get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}), + get_details_cmd_json_output( + annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]} + ), + ), + ( + "cli", + get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}), + get_details_cmd_console_output( + annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]} + ), + ), + ( + "json", + get_test_system(annotations={}), + get_details_cmd_json_output(annotations={}), + ), + ( + "cli", + get_test_system(annotations={}), + get_details_cmd_console_output(annotations={}), + ), + ], +) +def test_details_cmd( + cli_runner: CliRunner, + monkeypatch: Any, + format_: str, + system: System, + expected_output: str, +) -> None: + """Test details command with different formats output.""" + mock_get_system = MagicMock() + mock_get_system.return_value = system + monkeypatch.setattr("aiet.cli.system.get_system", mock_get_system) + + args = ["--name", "system"] + obj = {"format": format_} + details_cmd.params[0].type = click.Choice(["system"]) + + result = cli_runner.invoke(details_cmd, args=args, obj=obj) + assert result.output == expected_output + + +def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test install system command.""" + mock_install_system = MagicMock() + monkeypatch.setattr("aiet.cli.system.install_system", mock_install_system) + + args = ["--source", "test"] + cli_runner.invoke(install_cmd, args=args) + mock_install_system.assert_called_once_with(Path("test")) + + +def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test remove system command.""" + mock_remove_system = MagicMock() + monkeypatch.setattr("aiet.cli.system.remove_system", mock_remove_system) + remove_cmd.params[0].type = click.Choice(["test"]) + + args = ["--directory_name", "test"] + cli_runner.invoke(remove_cmd, args=args) + mock_remove_system.assert_called_once_with("test") diff --git a/tests/aiet/test_cli_tool.py b/tests/aiet/test_cli_tool.py new file mode 100644 index 0000000..45d45c8 --- /dev/null +++ b/tests/aiet/test_cli_tool.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals +"""Module for testing CLI tool subcommand.""" +import json +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from unittest.mock import MagicMock + +import click +import pytest +from click.testing import CliRunner +from click.testing import Result + +from aiet.backend.tool import get_unique_tool_names +from aiet.backend.tool import Tool +from aiet.cli.tool import details_cmd +from aiet.cli.tool import execute_cmd +from aiet.cli.tool import list_cmd +from aiet.cli.tool import tool_cmd + + +def test_tool_cmd() -> None: + """Test tool commands.""" + commands = ["list", "details", "execute"] + assert all(command in tool_cmd.commands for command in commands) + + +@pytest.mark.parametrize("format_", ["json", "cli"]) +def test_tool_cmd_context(cli_runner: CliRunner, format_: str) -> None: + """Test setting command context parameters.""" + result = cli_runner.invoke(tool_cmd, ["--format", format_]) + # command should fail if no subcommand provided + assert result.exit_code == 2 + + result = cli_runner.invoke(tool_cmd, ["--format", format_, "list"]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "format_, expected_output", + [ + ( + "json", + '{"type": "tool", "available": ["tool_1", "tool_2"]}\n', + ), + ("cli", "Available tools:\n\ntool_1\ntool_2\n"), + ], +) +def test_list_cmd( + cli_runner: CliRunner, + monkeypatch: Any, + format_: str, + expected_output: str, +) -> None: + """Test available tool commands.""" + # Mock some tools + mock_tool_1 = MagicMock(spec=Tool) + mock_tool_1.name = "tool_1" + mock_tool_2 = MagicMock(spec=Tool) + mock_tool_2.name = "tool_2" + + # Monkey patch the call get_available_tools + mock_available_tools = MagicMock() + mock_available_tools.return_value = [mock_tool_1, mock_tool_2] + + monkeypatch.setattr("aiet.backend.tool.get_available_tools", mock_available_tools) + + obj = {"format": format_} + args: Sequence[str] = [] + result = cli_runner.invoke(list_cmd, obj=obj, args=args) + assert result.output == expected_output + + +def get_details_cmd_json_output() -> List[dict]: + """Get JSON output for details command.""" + json_output = [ + { + "type": "tool", + "name": "tool_1", + "description": "This is tool 1", + "supported_systems": ["System 1"], + "commands": { + "clean": {"command_strings": ["echo 'clean'"], "user_params": []}, + "build": {"command_strings": ["echo 'build'"], "user_params": []}, + "run": {"command_strings": ["echo 'run'"], "user_params": []}, + "post_run": {"command_strings": ["echo 'post_run'"], "user_params": []}, + }, + } + ] + + return json_output + + +def get_details_cmd_console_output() -> str: + """Get console output for details command.""" + return ( + 'Tool "tool_1" details' + "\nDescription: This is tool 1" + "\n\nSupported systems: System 1" + "\n\nclean commands:" + "\nCommands: [\"echo 'clean'\"]" + "\n\nbuild commands:" + "\nCommands: [\"echo 'build'\"]" + "\n\nrun commands:\nCommands: [\"echo 'run'\"]" + "\n\npost_run commands:" + "\nCommands: [\"echo 'post_run'\"]" + "\n" + ) + + +@pytest.mark.parametrize( + [ + "tool_name", + "format_", + "expected_success", + "expected_output", + ], + [ + ("tool_1", "json", True, get_details_cmd_json_output()), + ("tool_1", "cli", True, get_details_cmd_console_output()), + ("non-existent tool", "json", False, None), + ("non-existent tool", "cli", False, None), + ], +) +def test_details_cmd( + cli_runner: CliRunner, + tool_name: str, + format_: str, + expected_success: bool, + expected_output: str, +) -> None: + """Test tool details command.""" + details_cmd.params[0].type = click.Choice(["tool_1", "tool_2", "vela"]) + result = cli_runner.invoke( + details_cmd, obj={"format": format_}, args=["--name", tool_name] + ) + success = result.exit_code == 0 + assert success == expected_success, result.output + if expected_success: + assert result.exception is None + output = json.loads(result.output) if format_ == "json" else result.output + assert output == expected_output + + +@pytest.mark.parametrize( + "system_name", + [ + "", + "Corstone-300: Cortex-M55+Ethos-U55", + "Corstone-300: Cortex-M55+Ethos-U65", + "Corstone-310: Cortex-M85+Ethos-U55", + ], +) +def test_details_cmd_vela(cli_runner: CliRunner, system_name: str) -> None: + """Test tool details command for Vela.""" + details_cmd.params[0].type = click.Choice(get_unique_tool_names()) + details_cmd.params[1].type = click.Choice([system_name]) + args = ["--name", "vela"] + if system_name: + args += ["--system", system_name] + result = cli_runner.invoke(details_cmd, obj={"format": "json"}, args=args) + success = result.exit_code == 0 + assert success, result.output + result_json = json.loads(result.output) + assert result_json + if system_name: + assert len(result_json) == 1 + tool = result_json[0] + assert len(tool["supported_systems"]) == 1 + assert system_name == tool["supported_systems"][0] + else: # no system specified => list details for all systems + assert len(result_json) == 3 + assert all(len(tool["supported_systems"]) == 1 for tool in result_json) + + +@pytest.fixture(scope="session") +def input_model_file(non_optimised_input_model_file: Path) -> Path: + """Provide the path to a quantized dummy model file in the test_resources_path.""" + return non_optimised_input_model_file + + +def execute_vela( + cli_runner: CliRunner, + tool_name: str = "vela", + system_name: Optional[str] = None, + input_model: Optional[Path] = None, + output_model: Optional[Path] = None, + mac: Optional[int] = None, + format_: str = "cli", +) -> Result: + """Run Vela with different parameters.""" + execute_cmd.params[0].type = click.Choice(get_unique_tool_names()) + execute_cmd.params[2].type = click.Choice([system_name or "dummy_system"]) + args = ["--name", tool_name] + if system_name is not None: + args += ["--system", system_name] + if input_model is not None: + args += ["--param", "input={}".format(input_model)] + if output_model is not None: + args += ["--param", "output={}".format(output_model)] + if mac is not None: + args += ["--param", "mac={}".format(mac)] + result = cli_runner.invoke( + execute_cmd, + args=args, + obj={"format": format_}, + ) + return result + + +@pytest.mark.parametrize("format_", ["cli, json"]) +@pytest.mark.parametrize( + ["tool_name", "system_name", "mac", "expected_success", "expected_output"], + [ + ("vela", "System 1", 32, False, None), # system not supported + ("vela", "NON-EXISTENT SYSTEM", 128, False, None), # system does not exist + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 32, True, None), + ("NON-EXISTENT TOOL", "Corstone-300: Cortex-M55+Ethos-U55", 32, False, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 64, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 128, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 256, True, None), + ( + "vela", + "Corstone-300: Cortex-M55+Ethos-U55", + 512, + False, + None, + ), # mac not supported + ( + "vela", + "Corstone-300: Cortex-M55+Ethos-U65", + 32, + False, + None, + ), # mac not supported + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 256, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 512, True, None), + ( + "vela", + None, + 512, + False, + "Error: Please specify the system for tool vela.", + ), # no system specified + ( + "NON-EXISTENT TOOL", + "Corstone-300: Cortex-M55+Ethos-U65", + 512, + False, + None, + ), # tool does not exist + ("vela", "Corstone-310: Cortex-M85+Ethos-U55", 128, True, None), + ], +) +def test_vela_run( + cli_runner: CliRunner, + format_: str, + input_model_file: Path, # pylint: disable=redefined-outer-name + tool_name: str, + system_name: Optional[str], + mac: int, + expected_success: bool, + expected_output: Optional[str], + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test the execution of the Vela command.""" + monkeypatch.chdir(tmp_path) + + output_file = Path("vela_output.tflite") + + result = execute_vela( + cli_runner, + tool_name=tool_name, + system_name=system_name, + input_model=input_model_file, + output_model=output_file, + mac=mac, + format_=format_, + ) + + success = result.exit_code == 0 + assert success == expected_success + if success: + # Check output file + output_file = output_file.resolve() + assert output_file.is_file() + if expected_output: + assert result.output.strip() == expected_output + + +@pytest.mark.parametrize("include_input_model", [True, False]) +@pytest.mark.parametrize("include_output_model", [True, False]) +@pytest.mark.parametrize("include_mac", [True, False]) +def test_vela_run_missing_params( + cli_runner: CliRunner, + input_model_file: Path, # pylint: disable=redefined-outer-name + include_input_model: bool, + include_output_model: bool, + include_mac: bool, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test the execution of the Vela command with missing user parameters.""" + monkeypatch.chdir(tmp_path) + + output_model_file = Path("output_model.tflite") + system_name = "Corstone-300: Cortex-M55+Ethos-U65" + mac = 256 + # input_model is a required parameters, but mac and output_model have default values. + expected_success = include_input_model + + result = execute_vela( + cli_runner, + tool_name="vela", + system_name=system_name, + input_model=input_model_file if include_input_model else None, + output_model=output_model_file if include_output_model else None, + mac=mac if include_mac else None, + ) + + success = result.exit_code == 0 + assert success == expected_success, ( + f"Success is {success}, but expected {expected_success}. " + f"Included params: [" + f"input_model={include_input_model}, " + f"output_model={include_output_model}, " + f"mac={include_mac}]" + ) diff --git a/tests/aiet/test_main.py b/tests/aiet/test_main.py new file mode 100644 index 0000000..f2ebae2 --- /dev/null +++ b/tests/aiet/test_main.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for testing AIET main.py.""" +from typing import Any +from unittest.mock import MagicMock + +from aiet import main + + +def test_main(monkeypatch: Any) -> None: + """Test main entry point function.""" + with monkeypatch.context() as mock_context: + mock = MagicMock() + mock_context.setattr(main, "cli", mock) + main.main() + mock.assert_called_once() diff --git a/tests/aiet/test_resources/application_config.json b/tests/aiet/test_resources/application_config.json new file mode 100644 index 0000000..2dfcfec --- /dev/null +++ b/tests/aiet/test_resources/application_config.json @@ -0,0 +1,96 @@ +[ + { + "name": "application_1", + "description": "application number one", + "supported_systems": [ + "system_1", + "system_2" + ], + "build_dir": "build_dir_11", + "commands": { + "clean": [ + "clean_cmd_11" + ], + "build": [ + "build_cmd_11" + ], + "run": [ + "run_cmd_11" + ], + "post_run": [ + "post_run_cmd_11" + ] + }, + "user_params": { + "run": [ + { + "name": "run_param_11", + "values": [], + "description": "run param number one" + } + ], + "build": [ + { + "name": "build_param_11", + "values": [], + "description": "build param number one" + }, + { + "name": "build_param_12", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_13", + "values": [ + "value_1" + ], + "description": "build param number three with some value" + } + ] + } + }, + { + "name": "application_2", + "description": "application number two", + "supported_systems": [ + "system_2" + ], + "build_dir": "build_dir_21", + "commands": { + "clean": [ + "clean_cmd_21" + ], + "build": [ + "build_cmd_21", + "build_cmd_22" + ], + "run": [ + "run_cmd_21" + ], + "post_run": [ + "post_run_cmd_21" + ] + }, + "user_params": { + "build": [ + { + "name": "build_param_21", + "values": [], + "description": "build param number one" + }, + { + "name": "build_param_22", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_23", + "values": [], + "description": "build param number three" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/application_config.json.license b/tests/aiet/test_resources/application_config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/application_config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json b/tests/aiet/test_resources/applications/application1/aiet-config.json new file mode 100644 index 0000000..97f0401 --- /dev/null +++ b/tests/aiet/test_resources/applications/application1/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_1", + "description": "This is application 1", + "supported_systems": [ + { + "name": "System 1" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json.license b/tests/aiet/test_resources/applications/application1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json b/tests/aiet/test_resources/applications/application2/aiet-config.json new file mode 100644 index 0000000..e9122d3 --- /dev/null +++ b/tests/aiet/test_resources/applications/application2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_2", + "description": "This is application 2", + "supported_systems": [ + { + "name": "System 2" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json.license b/tests/aiet/test_resources/applications/application2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application3/readme.txt b/tests/aiet/test_resources/applications/application3/readme.txt new file mode 100644 index 0000000..8c72c05 --- /dev/null +++ b/tests/aiet/test_resources/applications/application3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +This application does not have json configuration file diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json b/tests/aiet/test_resources/applications/application4/aiet-config.json new file mode 100644 index 0000000..34dc780 --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "application_4", + "description": "This is application 4", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt . # {user_params:0}" + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json.license b/tests/aiet/test_resources/applications/application4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application4/hello_app.txt b/tests/aiet/test_resources/applications/application4/hello_app.txt new file mode 100644 index 0000000..2ec0d1d --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/hello_app.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +Hello from APP! diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json b/tests/aiet/test_resources/applications/application5/aiet-config.json new file mode 100644 index 0000000..5269409 --- /dev/null +++ b/tests/aiet/test_resources/applications/application5/aiet-config.json @@ -0,0 +1,160 @@ +[ + { + "name": "application_5", + "description": "This is application 5", + "build_dir": "default_build_dir", + "supported_systems": [ + { + "name": "System 1", + "lock": false + }, + { + "name": "System 2" + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "lock": true, + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5A", + "description": "This is application 5A", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5A", + "variables": { + "var1": "new value1" + } + }, + { + "name": "System 2", + "variables": { + "var2": "new value2" + }, + "lock": true, + "commands": { + "run": [ + "run command on system 2" + ] + } + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "build_dir": "build", + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5B", + "description": "This is application 5B", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5B", + "variables": { + "var1": "value for var1 System1", + "var2": "value for var2 System1" + }, + "user_params": { + "build": [ + { + "name": "--param_5B", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + }, + { + "name": "System 2", + "variables": { + "var1": "value for var1 System2", + "var2": "value for var2 System2" + }, + "commands": { + "build": [ + "build command on system 2 with {variables:var1} {user_params:param1}" + ], + "run": [ + "run command on system 2" + ] + }, + "user_params": { + "run": [] + } + } + ], + "build_dir": "build", + "commands": { + "build": [ + "default build command with {variables:var1}" + ], + "run": [ + "default run command with {variables:var2}" + ] + }, + "user_params": { + "build": [ + { + "name": "--param", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ], + "run": [], + "non_used_command": [ + { + "name": "--not-used", + "description": "Not used param anywhere", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + } +] diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json.license b/tests/aiet/test_resources/applications/application5/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application5/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/readme.txt b/tests/aiet/test_resources/applications/readme.txt new file mode 100644 index 0000000..a1f8209 --- /dev/null +++ b/tests/aiet/test_resources/applications/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +Dummy file for test purposes diff --git a/tests/aiet/test_resources/hello_world.json b/tests/aiet/test_resources/hello_world.json new file mode 100644 index 0000000..8a9a448 --- /dev/null +++ b/tests/aiet/test_resources/hello_world.json @@ -0,0 +1,54 @@ +[ + { + "name": "Hello world", + "description": "Dummy application that displays 'Hello world!'", + "supported_systems": [ + "Dummy System" + ], + "build_dir": "build", + "deploy_data": [ + [ + "src", + "/tmp/" + ], + [ + "README", + "/tmp/README.md" + ] + ], + "commands": { + "clean": [], + "build": [], + "run": [ + "echo 'Hello world!'", + "ls -l /tmp" + ], + "post_run": [] + }, + "user_params": { + "run": [ + { + "name": "--choice-param", + "values": [ + "dummy_value_1", + "dummy_value_2" + ], + "default_value": "dummy_value_1", + "description": "Choice param" + }, + { + "name": "--open-param", + "values": [], + "default_value": "dummy_value_4", + "description": "Open param" + }, + { + "name": "--enable-flag", + "default_value": "dummy_value_4", + "description": "Flag param" + } + ], + "build": [] + } + } +] diff --git a/tests/aiet/test_resources/hello_world.json.license b/tests/aiet/test_resources/hello_world.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/hello_world.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/scripts/test_backend_run b/tests/aiet/test_resources/scripts/test_backend_run new file mode 100755 index 0000000..548f577 --- /dev/null +++ b/tests/aiet/test_resources/scripts/test_backend_run @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" +sleep 100 diff --git a/tests/aiet/test_resources/scripts/test_backend_run_script.sh b/tests/aiet/test_resources/scripts/test_backend_run_script.sh new file mode 100644 index 0000000..548f577 --- /dev/null +++ b/tests/aiet/test_resources/scripts/test_backend_run_script.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" +sleep 100 diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json b/tests/aiet/test_resources/systems/system1/aiet-config.json new file mode 100644 index 0000000..4b5dd19 --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "System 1", + "description": "This is system 1", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ], + "deploy": [ + "echo 'deploy'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json.license b/tests/aiet/test_resources/systems/system1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt new file mode 100644 index 0000000..487e9d8 --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json b/tests/aiet/test_resources/systems/system2/aiet-config.json new file mode 100644 index 0000000..a9e0eb3 --- /dev/null +++ b/tests/aiet/test_resources/systems/system2/aiet-config.json @@ -0,0 +1,32 @@ +[ + { + "name": "System 2", + "description": "This is system 2", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json.license b/tests/aiet/test_resources/systems/system2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system3/readme.txt b/tests/aiet/test_resources/systems/system3/readme.txt new file mode 100644 index 0000000..aba5a9c --- /dev/null +++ b/tests/aiet/test_resources/systems/system3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +This system does not have the json configuration file diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json b/tests/aiet/test_resources/systems/system4/aiet-config.json new file mode 100644 index 0000000..295e00f --- /dev/null +++ b/tests/aiet/test_resources/systems/system4/aiet-config.json @@ -0,0 +1,19 @@ +[ + { + "name": "System 4", + "description": "This is system 4", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [ + "echo {application.name}", + "cat {application.commands.run:0}" + ] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json.license b/tests/aiet/test_resources/systems/system4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json b/tests/aiet/test_resources/tools/tool1/aiet-config.json new file mode 100644 index 0000000..067ef7e --- /dev/null +++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "tool_1", + "description": "This is tool 1", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 1" + } + ], + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json b/tests/aiet/test_resources/tools/tool2/aiet-config.json new file mode 100644 index 0000000..6eee9a6 --- /dev/null +++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json @@ -0,0 +1,26 @@ +[ + { + "name": "tool_2", + "description": "This is tool 2 with no supported systems", + "build_dir": "build", + "supported_systems": [], + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json new file mode 100644 index 0000000..ff1cf1a --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json new file mode 100644 index 0000000..724b31b --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json @@ -0,0 +1,2 @@ +This is not valid json file +{ diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json new file mode 100644 index 0000000..1ebb29c --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json new file mode 100644 index 0000000..410d12d --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "anme": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json new file mode 100644 index 0000000..20142e9 --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json @@ -0,0 +1,16 @@ +[ + { + "name": "Test system", + "description": "This is a test system", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_run_vela_script.py b/tests/aiet/test_run_vela_script.py new file mode 100644 index 0000000..971856e --- /dev/null +++ b/tests/aiet/test_run_vela_script.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name,no-self-use +"""Module for testing run_vela.py script.""" +from pathlib import Path +from typing import Any +from typing import List + +import pytest +from click.testing import CliRunner + +from aiet.cli.common import MiddlewareExitCode +from aiet.resources.tools.vela.check_model import get_model_from_file +from aiet.resources.tools.vela.check_model import is_vela_optimised +from aiet.resources.tools.vela.run_vela import run_vela + + +@pytest.fixture(scope="session") +def vela_config_path(test_tools_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_tools_path / "vela" / "vela.ini" + + +@pytest.fixture( + params=[ + ["ethos-u65-256", "Ethos_U65_High_End", "U65_Shared_Sram"], + ["ethos-u55-32", "Ethos_U55_High_End_Embedded", "U55_Shared_Sram"], + ] +) +def ethos_config(request: Any) -> Any: + """Fixture to provide different configuration for Ethos-U optimization with Vela.""" + return request.param + + +# pylint: disable=too-many-arguments +def generate_args( + input_: Path, + output: Path, + cfg: Path, + acc_config: str, + system_config: str, + memory_mode: str, +) -> List[str]: + """Generate arguments that can be passed to script 'run_vela'.""" + return [ + "-i", + str(input_), + "-o", + str(output), + "--config", + str(cfg), + "--accelerator-config", + acc_config, + "--system-config", + system_config, + "--memory-mode", + memory_mode, + "--optimise", + "Performance", + ] + + +def check_run_vela( + cli_runner: CliRunner, args: List, expected_success: bool, output_file: Path +) -> None: + """Run Vela with the given arguments and check the result.""" + result = cli_runner.invoke(run_vela, args) + success = result.exit_code == MiddlewareExitCode.SUCCESS + assert success == expected_success + if success: + model = get_model_from_file(output_file) + assert is_vela_optimised(model) + + +def run_vela_script( + cli_runner: CliRunner, + input_model_file: Path, + output_model_file: Path, + vela_config: Path, + expected_success: bool, + acc_config: str, + system_config: str, + memory_mode: str, +) -> None: + """Run the command 'run_vela' on the command line.""" + args = generate_args( + input_model_file, + output_model_file, + vela_config, + acc_config, + system_config, + memory_mode, + ) + check_run_vela(cli_runner, args, expected_success, output_model_file) + + +class TestRunVelaCli: + """Test the command-line execution of the run_vela command.""" + + def test_non_optimised_model( + self, + cli_runner: CliRunner, + non_optimised_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify Vela is run correctly on an unoptimised model.""" + run_vela_script( + cli_runner, + non_optimised_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + True, + *ethos_config, + ) + + def test_optimised_model( + self, + cli_runner: CliRunner, + optimised_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify Vela is run correctly on an already optimised model.""" + run_vela_script( + cli_runner, + optimised_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + True, + *ethos_config, + ) + + def test_invalid_model( + self, + cli_runner: CliRunner, + invalid_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify an error is raised when the input model is not valid.""" + run_vela_script( + cli_runner, + invalid_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + False, + *ethos_config, + ) diff --git a/tests/aiet/test_utils_fs.py b/tests/aiet/test_utils_fs.py new file mode 100644 index 0000000..46d276e --- /dev/null +++ b/tests/aiet/test_utils_fs.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Module for testing fs.py.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from aiet.utils.fs import get_resources +from aiet.utils.fs import read_file_as_bytearray +from aiet.utils.fs import read_file_as_string +from aiet.utils.fs import recreate_directory +from aiet.utils.fs import remove_directory +from aiet.utils.fs import remove_resource +from aiet.utils.fs import ResourceType +from aiet.utils.fs import valid_for_filename + + +@pytest.mark.parametrize( + "resource_name,expected_path", + [ + ("systems", does_not_raise()), + ("applications", does_not_raise()), + ("whaaat", pytest.raises(ResourceWarning)), + (None, pytest.raises(ResourceWarning)), + ], +) +def test_get_resources(resource_name: ResourceType, expected_path: Any) -> None: + """Test get_resources() with multiple parameters.""" + with expected_path: + resource_path = get_resources(resource_name) + assert resource_path.exists() + + +def test_remove_resource_wrong_directory( + monkeypatch: Any, test_applications_path: Path +) -> None: + """Test removing resource with wrong directory.""" + mock_get_resources = MagicMock(return_value=test_applications_path) + monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) + + with pytest.raises(Exception, match="Resource .* does not exist"): + remove_resource("unknown", "applications") + mock_shutil_rmtree.assert_not_called() + + with pytest.raises(Exception, match="Wrong resource .*"): + remove_resource("readme.txt", "applications") + mock_shutil_rmtree.assert_not_called() + + +def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None: + """Test removing resource data.""" + mock_get_resources = MagicMock(return_value=test_applications_path) + monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) + + remove_resource("application1", "applications") + mock_shutil_rmtree.assert_called_once() + + +def test_remove_directory(tmpdir: Any) -> None: + """Test directory removal.""" + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + + for item in [None, tmpfile]: + with pytest.raises(Exception, match="No directory path provided"): + remove_directory(item) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + + assert newdir.is_dir() + remove_directory(newdir) + assert not newdir.exists() + + +def test_recreate_directory(tmpdir: Any) -> None: + """Test directory recreation.""" + with pytest.raises(Exception, match="No directory path provided"): + recreate_directory(None) + + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + tmpfile.touch() + with pytest.raises(Exception, match="Path .* does exist and it is not a directory"): + recreate_directory(tmpfile) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + newfile = newdir / "newfile" + newfile.touch() + assert list(newdir.iterdir()) == [newfile] + recreate_directory(newdir) + assert not list(newdir.iterdir()) + + newdir2 = tmpdir_path / "newdir2" + assert not newdir2.exists() + recreate_directory(newdir2) + assert newdir2.is_dir() + + +def write_to_file( + write_directory: Any, write_mode: str, write_text: Union[str, bytes] +) -> Path: + """Write some text to a temporary test file.""" + tmpdir_path = Path(write_directory) + tmpfile = tmpdir_path / "file_name.txt" + with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding + file.write(write_text) + return tmpfile + + +class TestReadFileAsString: + """Test read_file_as_string() function.""" + + def test_returns_text_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the string written to a file read correctly.""" + file_path = write_to_file(tmpdir, "w", "hello") + assert read_file_as_string(file_path) == "hello" + + def test_output_is_empty_string_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty string returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_string(file_path) == "" + + +class TestReadFileAsByteArray: + """Test read_file_as_bytearray() function.""" + + def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the bytes written to a file read correctly.""" + file_path = write_to_file(tmpdir, "wb", b"hello bytes") + assert read_file_as_bytearray(file_path) == b"hello bytes" + + def test_output_is_empty_bytearray_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty bytearray returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_bytearray(file_path) == bytearray() + + +@pytest.mark.parametrize( + "value, replacement, expected_result", + [ + ["", "", ""], + ["123", "", "123"], + ["123", "_", "123"], + ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"], + ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"], + ["!;'some_name$%^!", "_", "___some_name____"], + ], +) +def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None: + """Test function valid_for_filename.""" + assert valid_for_filename(value, replacement) == expected_result diff --git a/tests/aiet/test_utils_helpers.py b/tests/aiet/test_utils_helpers.py new file mode 100644 index 0000000..bbe03fc --- /dev/null +++ b/tests/aiet/test_utils_helpers.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for testing helpers.py.""" +import logging +from typing import Any +from typing import List +from unittest.mock import call +from unittest.mock import MagicMock + +import pytest + +from aiet.utils.helpers import set_verbosity + + +@pytest.mark.parametrize( + "verbosity,expected_calls", + [(0, []), (1, [call(logging.INFO)]), (2, [call(logging.DEBUG)])], +) +def test_set_verbosity( + verbosity: int, expected_calls: List[Any], monkeypatch: Any +) -> None: + """Test set_verbosity() with different verbsosity levels.""" + with monkeypatch.context() as mock_context: + logging_mock = MagicMock() + mock_context.setattr(logging.getLogger(), "setLevel", logging_mock) + set_verbosity(None, None, verbosity) + logging_mock.assert_has_calls(expected_calls) diff --git a/tests/aiet/test_utils_proc.py b/tests/aiet/test_utils_proc.py new file mode 100644 index 0000000..9fb48dd --- /dev/null +++ b/tests/aiet/test_utils_proc.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable +"""Pytests for testing aiet/utils/proc.py.""" +from pathlib import Path +from typing import Any +from unittest import mock + +import psutil +import pytest +from sh import ErrorReturnCode + +from aiet.utils.proc import Command +from aiet.utils.proc import CommandFailedException +from aiet.utils.proc import CommandNotFound +from aiet.utils.proc import parse_command +from aiet.utils.proc import print_command_stdout +from aiet.utils.proc import run_and_wait +from aiet.utils.proc import save_process_info +from aiet.utils.proc import ShellCommand +from aiet.utils.proc import terminate_command +from aiet.utils.proc import terminate_external_process + + +class TestShellCommand: + """Sample class for collecting tests.""" + + def test_shellcommand_default_value(self) -> None: + """Test the instantiation of the class ShellCommand with no parameter.""" + shell_command = ShellCommand() + assert shell_command.base_log_path == "/tmp" + + @pytest.mark.parametrize( + "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")] + ) + def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None: + """Test init ShellCommand with different parameters.""" + shell_command = ShellCommand(base_log_path) + assert shell_command.base_log_path == expected + + def test_run_ls(self, monkeypatch: Any) -> None: + """Test a simple ls command.""" + mock_command = mock.MagicMock() + monkeypatch.setattr(Command, "bake", mock_command) + + mock_get_stdout_stderr_paths = mock.MagicMock() + mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err") + monkeypatch.setattr( + ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths + ) + + shell_command = ShellCommand() + shell_command.run("ls", "-l") + assert mock_command.mock_calls[0] == mock.call(("-l",)) + assert mock_command.mock_calls[1] == mock.call()( + _bg=True, _err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False + ) + + def test_run_command_not_found(self) -> None: + """Test whe the command doesn't exist.""" + shell_command = ShellCommand() + with pytest.raises(CommandNotFound): + shell_command.run("lsl", "-l") + + def test_get_stdout_stderr_paths_valid_path(self) -> None: + """Test the method to get files to store stdout and stderr.""" + valid_path = "/tmp" + shell_command = ShellCommand(valid_path) + out, err = shell_command.get_stdout_stderr_paths(valid_path, "cmd") + assert out.exists() and out.is_file() + assert err.exists() and err.is_file() + assert "cmd" in out.name + assert "cmd" in err.name + + def test_get_stdout_stderr_paths_not_invalid_path(self) -> None: + """Test the method to get output files with an invalid path.""" + invalid_path = "/invalid/foo/bar" + shell_command = ShellCommand(invalid_path) + with pytest.raises(FileNotFoundError): + shell_command.get_stdout_stderr_paths(invalid_path, "cmd") + + +@mock.patch("builtins.print") +def test_print_command_stdout_alive(mock_print: Any) -> None: + """Test the print command stdout with an alive (running) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = True + mock_command.next.side_effect = ["test1", "test2", StopIteration] + + print_command_stdout(mock_command) + + mock_command.assert_has_calls( + [mock.call.is_alive(), mock.call.next(), mock.call.next()] + ) + mock_print.assert_has_calls( + [mock.call("test1", end=""), mock.call("test2", end="")] + ) + + +@mock.patch("builtins.print") +def test_print_command_stdout_not_alive(mock_print: Any) -> None: + """Test the print command stdout with a not alive (exited) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + mock_command.stdout = "test" + + print_command_stdout(mock_command) + mock_command.assert_has_calls([mock.call.is_alive()]) + mock_print.assert_called_once_with("test") + + +def test_terminate_external_process_no_process(capsys: Any) -> None: + """Test that non existed process could be terminated.""" + mock_command = mock.MagicMock() + mock_command.terminate.side_effect = psutil.Error("Error!") + + terminate_external_process(mock_command) + captured = capsys.readouterr() + assert captured.out == "Unable to terminate process\n" + + +def test_terminate_external_process_case1() -> None: + """Test when process terminated immediately.""" + mock_command = mock.MagicMock() + mock_command.is_running.return_value = False + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + mock_command.is_running.assert_called_once() + + +def test_terminate_external_process_case2() -> None: + """Test when process termination takes time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + + +def test_terminate_external_process_case3() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, True] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 2 + + +def test_terminate_external_process_case4() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 1 + + +def test_terminate_command_no_process() -> None: + """Test command termination when process does not exist.""" + mock_command = mock.MagicMock() + mock_command.process.signal_group.side_effect = ProcessLookupError() + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + mock_command.is_alive.assert_not_called() + + +def test_terminate_command() -> None: + """Test command termination.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + + +def test_terminate_command_case1() -> None: + """Test command termination when it takes time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, False] + + terminate_command(mock_command, wait_period=0.1) + mock_command.process.signal_group.assert_called_once() + assert mock_command.is_alive.call_count == 3 + + +def test_terminate_command_case2() -> None: + """Test command termination when it takes much time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, True] + + terminate_command(mock_command, number_of_attempts=3, wait_period=0.1) + assert mock_command.is_alive.call_count == 3 + assert mock_command.process.signal_group.call_count == 2 + + +class TestRunAndWait: + """Test run_and_wait function.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Init test method.""" + self.execute_command_mock = mock.MagicMock() + monkeypatch.setattr( + "aiet.utils.proc.execute_command", self.execute_command_mock + ) + + self.terminate_command_mock = mock.MagicMock() + monkeypatch.setattr( + "aiet.utils.proc.terminate_command", self.terminate_command_mock + ) + + def test_if_execute_command_raises_exception(self) -> None: + """Test if execute_command fails.""" + self.execute_command_mock.side_effect = Exception("Error!") + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd()) + + def test_if_command_finishes_with_error(self) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock( + side_effect=ErrorReturnCode("cmd", bytearray(), bytearray()) + ) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(CommandFailedException): + run_and_wait("command", Path.cwd()) + + @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1))) + def test_if_command_finishes_with_exception( + self, terminate_on_error: bool, call_count: int + ) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!")) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error) + + assert self.terminate_command_mock.call_count == call_count + + +def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None: + """Test save_process_info function.""" + mock_process = mock.MagicMock() + monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process)) + mock_process.children.side_effect = psutil.NoSuchProcess(555) + + pid_file_path = Path(tmpdir) / "test.pid" + save_process_info(555, pid_file_path) + assert not pid_file_path.exists() + + +def test_parse_command() -> None: + """Test parse_command function.""" + assert parse_command("1.sh") == ["bash", "1.sh"] + assert parse_command("1.sh", shell="sh") == ["sh", "1.sh"] + assert parse_command("command") == ["command"] + assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5c6156c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Pytest conf module.""" +import shutil +from pathlib import Path +from typing import Generator + +import pytest +import tensorflow as tf + +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import save_keras_model +from mlia.nn.tensorflow.utils import save_tflite_model +from mlia.tools.vela_wrapper import optimize_model + + +def get_test_keras_model() -> tf.keras.Model: + """Return test Keras model.""" + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + tf.keras.layers.Reshape((28, 28, 1)), + tf.keras.layers.Conv2D( + filters=12, kernel_size=(3, 3), activation="relu", name="conv1" + ), + tf.keras.layers.Conv2D( + filters=12, kernel_size=(3, 3), activation="relu", name="conv2" + ), + tf.keras.layers.MaxPool2D(2, 2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(10, name="output"), + ] + ) + + model.compile(optimizer="sgd", loss="mean_squared_error") + return model + + +@pytest.fixture(scope="session", name="test_models_path") +def fixture_test_models_path( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Provide path to the test models.""" + tmp_path = tmp_path_factory.mktemp("models") + + keras_model = get_test_keras_model() + save_keras_model(keras_model, tmp_path / "test_model.h5") + + tflite_model = convert_to_tflite(keras_model, quantized=True) + tflite_model_path = tmp_path / "test_model.tflite" + save_tflite_model(tflite_model, tflite_model_path) + + tflite_vela_model = tmp_path / "test_model_vela.tflite" + device = EthosUConfiguration("ethos-u55-256") + optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model) + + tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model")) + + invalid_tflite_model = tmp_path / "invalid.tflite" + invalid_tflite_model.touch() + + yield tmp_path + + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="session", name="test_keras_model") +def fixture_test_keras_model(test_models_path: Path) -> Path: + """Return test Keras model.""" + return test_models_path / "test_model.h5" + + +@pytest.fixture(scope="session", name="test_tflite_model") +def fixture_test_tflite_model(test_models_path: Path) -> Path: + """Return test TFLite model.""" + return test_models_path / "test_model.tflite" + + +@pytest.fixture(scope="session", name="test_tflite_vela_model") +def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: + """Return test Vela-optimized TFLite model.""" + return test_models_path / "test_model_vela.tflite" + + +@pytest.fixture(scope="session", name="test_tf_model") +def fixture_test_tf_model(test_models_path: Path) -> Path: + """Return test TFLite model.""" + return test_models_path / "tf_model_test_model" + + +@pytest.fixture(scope="session", name="test_tflite_invalid_model") +def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path: + """Return test invalid TFLite model.""" + return test_models_path / "invalid.tflite" diff --git a/tests/mlia/__init__.py b/tests/mlia/__init__.py new file mode 100644 index 0000000..0687f14 --- /dev/null +++ b/tests/mlia/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""MLIA tests module.""" diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py new file mode 100644 index 0000000..f683fca --- /dev/null +++ b/tests/mlia/conftest.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Pytest conf module.""" +from pathlib import Path + +import pytest + +from mlia.core.context import ExecutionContext + + +@pytest.fixture(scope="session", name="test_resources_path") +def fixture_test_resources_path() -> Path: + """Return test resources path.""" + return Path(__file__).parent / "test_resources" + + +@pytest.fixture(name="dummy_context") +def fixture_dummy_context(tmpdir: str) -> ExecutionContext: + """Return dummy context fixture.""" + return ExecutionContext(working_dir=tmpdir) diff --git a/tests/mlia/test_api.py b/tests/mlia/test_api.py new file mode 100644 index 0000000..54d4796 --- /dev/null +++ b/tests/mlia/test_api.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the API functions.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.api import get_advice +from mlia.core.common import AdviceCategory +from mlia.core.context import Context +from mlia.core.context import ExecutionContext + + +def test_get_advice_no_target_provided(test_keras_model: Path) -> None: + """Test getting advice when no target provided.""" + with pytest.raises(Exception, match="Target is not provided"): + get_advice(None, test_keras_model, "all") # type: ignore + + +def test_get_advice_wrong_category(test_keras_model: Path) -> None: + """Test getting advice when wrong advice category provided.""" + with pytest.raises(Exception, match="Invalid advice category unknown"): + get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore + + +@pytest.mark.parametrize( + "category, context, expected_category", + [ + [ + "all", + None, + AdviceCategory.ALL, + ], + [ + "optimization", + None, + AdviceCategory.OPTIMIZATION, + ], + [ + "operators", + None, + AdviceCategory.OPERATORS, + ], + [ + "performance", + None, + AdviceCategory.PERFORMANCE, + ], + [ + "all", + ExecutionContext(), + AdviceCategory.ALL, + ], + [ + "all", + ExecutionContext(advice_category=AdviceCategory.PERFORMANCE), + AdviceCategory.PERFORMANCE, + ], + [ + "all", + ExecutionContext(config_parameters={"param": "value"}), + AdviceCategory.ALL, + ], + [ + "all", + ExecutionContext(event_handlers=[MagicMock()]), + AdviceCategory.ALL, + ], + ], +) +def test_get_advice( + monkeypatch: pytest.MonkeyPatch, + category: str, + context: ExecutionContext, + expected_category: AdviceCategory, + test_keras_model: Path, +) -> None: + """Test getting advice with valid parameters.""" + advisor_mock = MagicMock() + monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock)) + + get_advice( + "ethos-u55-256", + test_keras_model, + category, # type: ignore + context=context, + ) + + advisor_mock.run.assert_called_once() + context = advisor_mock.run.mock_calls[0].args[0] + assert isinstance(context, Context) + assert context.advice_category == expected_category + + assert context.event_handlers is not None + assert context.config_parameters is not None diff --git a/tests/mlia/test_cli_commands.py b/tests/mlia/test_cli_commands.py new file mode 100644 index 0000000..bf17339 --- /dev/null +++ b/tests/mlia/test_cli_commands.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cli.commands module.""" +from pathlib import Path +from typing import Any +from typing import Optional +from unittest.mock import call +from unittest.mock import MagicMock + +import pytest + +from mlia.cli.commands import backend +from mlia.cli.commands import operators +from mlia.cli.commands import optimization +from mlia.cli.commands import performance +from mlia.core.context import ExecutionContext +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.performance import MemoryUsage +from mlia.devices.ethosu.performance import NPUCycles +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.tools.metadata.common import InstallationManager + + +def test_operators_expected_parameters(dummy_context: ExecutionContext) -> None: + """Test operators command wrong parameters.""" + with pytest.raises(Exception, match="Model is not provided"): + operators(dummy_context, "ethos-u55-256") + + +def test_performance_unknown_target( + dummy_context: ExecutionContext, test_tflite_model: Path +) -> None: + """Test that command should fail if unknown target passed.""" + with pytest.raises(Exception, match="Unable to find target profile unknown"): + performance( + dummy_context, model=str(test_tflite_model), target_profile="unknown" + ) + + +@pytest.mark.parametrize( + "target_profile, optimization_type, optimization_target, expected_error", + [ + [ + "ethos-u55-256", + None, + "0.5", + pytest.raises(Exception, match="Optimization type is not provided"), + ], + [ + "ethos-u65-512", + "unknown", + "16", + pytest.raises(Exception, match="Unsupported optimization type: unknown"), + ], + [ + "ethos-u55-256", + "pruning", + None, + pytest.raises(Exception, match="Optimization target is not provided"), + ], + [ + "ethos-u65-512", + "clustering", + None, + pytest.raises(Exception, match="Optimization target is not provided"), + ], + [ + "unknown", + "clustering", + "16", + pytest.raises(Exception, match="Unable to find target profile unknown"), + ], + ], +) +def test_opt_expected_parameters( + dummy_context: ExecutionContext, + target_profile: str, + monkeypatch: pytest.MonkeyPatch, + optimization_type: str, + optimization_target: str, + expected_error: Any, + test_keras_model: Path, +) -> None: + """Test that command should fail if no or unknown optimization type provided.""" + mock_performance_estimation(monkeypatch) + + with expected_error: + optimization( + ctx=dummy_context, + target_profile=target_profile, + model=str(test_keras_model), + optimization_type=optimization_type, + optimization_target=optimization_target, + ) + + +@pytest.mark.parametrize( + "target_profile, optimization_type, optimization_target", + [ + ["ethos-u55-256", "pruning", "0.5"], + ["ethos-u65-512", "clustering", "32"], + ["ethos-u55-256", "pruning,clustering", "0.5,32"], + ], +) +def test_opt_valid_optimization_target( + target_profile: str, + dummy_context: ExecutionContext, + optimization_type: str, + optimization_target: str, + monkeypatch: pytest.MonkeyPatch, + test_keras_model: Path, +) -> None: + """Test that command should not fail with valid optimization targets.""" + mock_performance_estimation(monkeypatch) + + optimization( + ctx=dummy_context, + target_profile=target_profile, + model=str(test_keras_model), + optimization_type=optimization_type, + optimization_target=optimization_target, + ) + + +def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None: + """Mock performance estimation.""" + metrics = PerformanceMetrics( + EthosUConfiguration("ethos-u55-256"), + NPUCycles(1, 2, 3, 4, 5, 6), + MemoryUsage(1, 2, 3, 4, 5), + ) + monkeypatch.setattr( + "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate", + MagicMock(return_value=metrics), + ) + + +@pytest.fixture(name="installation_manager_mock") +def fixture_mock_installation_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Mock installation manager.""" + install_manager_mock = MagicMock(spec=InstallationManager) + monkeypatch.setattr( + "mlia.cli.commands.get_installation_manager", + MagicMock(return_value=install_manager_mock), + ) + return install_manager_mock + + +def test_backend_command_action_status(installation_manager_mock: MagicMock) -> None: + """Test backend command "status".""" + backend(backend_action="status") + + installation_manager_mock.show_env_details.assert_called_once() + + +@pytest.mark.parametrize( + "i_agree_to_the_contained_eula, backend_name, expected_calls", + [ + [False, None, [call(None, True)]], + [True, None, [call(None, False)]], + [False, "backend_name", [call("backend_name", True)]], + [True, "backend_name", [call("backend_name", False)]], + ], +) +def test_backend_command_action_add_downoad( + installation_manager_mock: MagicMock, + i_agree_to_the_contained_eula: bool, + backend_name: Optional[str], + expected_calls: Any, +) -> None: + """Test backend command "install" with download option.""" + backend( + backend_action="install", + download=True, + name=backend_name, + i_agree_to_the_contained_eula=i_agree_to_the_contained_eula, + ) + + assert installation_manager_mock.download_and_install.mock_calls == expected_calls + + +@pytest.mark.parametrize("backend_name", [None, "backend_name"]) +def test_backend_command_action_install_from_path( + installation_manager_mock: MagicMock, + tmp_path: Path, + backend_name: Optional[str], +) -> None: + """Test backend command "install" with backend path.""" + backend(backend_action="install", path=tmp_path, name=backend_name) + + installation_manager_mock.install_from(tmp_path, backend_name) + + +def test_backend_command_action_install_only_one_action( + installation_manager_mock: MagicMock, # pylint: disable=unused-argument + tmp_path: Path, +) -> None: + """Test that only one of action type allowed.""" + with pytest.raises( + Exception, + match="Please select only one action: download or " + "provide path to the backend installation", + ): + backend(backend_action="install", download=True, path=tmp_path) diff --git a/tests/mlia/test_cli_config.py b/tests/mlia/test_cli_config.py new file mode 100644 index 0000000..6d19eec --- /dev/null +++ b/tests/mlia/test_cli_config.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cli.config module.""" +from typing import List +from unittest.mock import MagicMock + +import pytest + +from mlia.cli.config import get_default_backends +from mlia.cli.config import is_corstone_backend + + +@pytest.mark.parametrize( + "available_backends, expected_default_backends", + [ + [["Vela"], ["Vela"]], + [["Corstone-300"], ["Corstone-300"]], + [["Corstone-310"], ["Corstone-310"]], + [["Corstone-300", "Corstone-310"], ["Corstone-310"]], + [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]], + [ + ["Vela", "Corstone-300", "Corstone-310", "New backend"], + ["Vela", "Corstone-310", "New backend"], + ], + [ + ["Vela", "Corstone-300", "New backend"], + ["Vela", "Corstone-300", "New backend"], + ], + ], +) +def test_get_default_backends( + monkeypatch: pytest.MonkeyPatch, + available_backends: List[str], + expected_default_backends: List[str], +) -> None: + """Test function get_default backends.""" + monkeypatch.setattr( + "mlia.cli.config.get_available_backends", + MagicMock(return_value=available_backends), + ) + + assert get_default_backends() == expected_default_backends + + +def test_is_corstone_backend() -> None: + """Test function is_corstone_backend.""" + assert is_corstone_backend("Corstone-300") is True + assert is_corstone_backend("Corstone-310") is True + assert is_corstone_backend("New backend") is False diff --git a/tests/mlia/test_cli_helpers.py b/tests/mlia/test_cli_helpers.py new file mode 100644 index 0000000..2c52885 --- /dev/null +++ b/tests/mlia/test_cli_helpers.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the helper classes.""" +from typing import Any +from typing import Dict +from typing import List + +import pytest + +from mlia.cli.helpers import CLIActionResolver +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings + + +class TestCliActionResolver: + """Test cli action resolver.""" + + @staticmethod + @pytest.mark.parametrize( + "args, params, expected_result", + [ + [ + {}, + {"opt_settings": "some_setting"}, + [], + ], + [ + {}, + {}, + [ + "Note: you will need a Keras model for that.", + "For example: mlia optimization --optimization-type " + "pruning,clustering --optimization-target 0.5,32 " + "/path/to/keras_model", + "For more info: mlia optimization --help", + ], + ], + [ + {"model": "model.h5"}, + {}, + [ + "For example: mlia optimization --optimization-type " + "pruning,clustering --optimization-target 0.5,32 model.h5", + "For more info: mlia optimization --help", + ], + ], + [ + {"model": "model.h5"}, + {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, + [ + "For more info: mlia optimization --help", + "Optimization command: " + "mlia optimization --optimization-type pruning " + "--optimization-target 0.5 model.h5", + ], + ], + [ + {"model": "model.h5", "target_profile": "target_profile"}, + {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, + [ + "For more info: mlia optimization --help", + "Optimization command: " + "mlia optimization --optimization-type pruning " + "--optimization-target 0.5 " + "--target-profile target_profile model.h5", + ], + ], + ], + ) + def test_apply_optimizations( + args: Dict[str, Any], + params: Dict[str, Any], + expected_result: List[str], + ) -> None: + """Test action resolving for applying optimizations.""" + resolver = CLIActionResolver(args) + assert resolver.apply_optimizations(**params) == expected_result + + @staticmethod + def test_supported_operators_info() -> None: + """Test supported operators info.""" + resolver = CLIActionResolver({}) + assert resolver.supported_operators_info() == [ + "For guidance on supported operators, run: mlia operators " + "--supported-ops-report", + ] + + @staticmethod + def test_operator_compatibility_details() -> None: + """Test operator compatibility details info.""" + resolver = CLIActionResolver({}) + assert resolver.operator_compatibility_details() == [ + "For more details, run: mlia operators --help" + ] + + @staticmethod + def test_optimization_details() -> None: + """Test optimization details info.""" + resolver = CLIActionResolver({}) + assert resolver.optimization_details() == [ + "For more info, see: mlia optimization --help" + ] + + @staticmethod + @pytest.mark.parametrize( + "args, expected_result", + [ + [ + {}, + [], + ], + [ + {"model": "model.tflite"}, + [ + "Check the estimated performance by running the " + "following command: ", + "mlia performance model.tflite", + ], + ], + [ + {"model": "model.tflite", "target_profile": "target_profile"}, + [ + "Check the estimated performance by running the " + "following command: ", + "mlia performance --target-profile target_profile model.tflite", + ], + ], + ], + ) + def test_check_performance( + args: Dict[str, Any], expected_result: List[str] + ) -> None: + """Test check performance info.""" + resolver = CLIActionResolver(args) + assert resolver.check_performance() == expected_result + + @staticmethod + @pytest.mark.parametrize( + "args, expected_result", + [ + [ + {}, + [], + ], + [ + {"model": "model.tflite"}, + [ + "Try running the following command to verify that:", + "mlia operators model.tflite", + ], + ], + [ + {"model": "model.tflite", "target_profile": "target_profile"}, + [ + "Try running the following command to verify that:", + "mlia operators --target-profile target_profile model.tflite", + ], + ], + ], + ) + def test_check_operator_compatibility( + args: Dict[str, Any], expected_result: List[str] + ) -> None: + """Test checking operator compatibility info.""" + resolver = CLIActionResolver(args) + assert resolver.check_operator_compatibility() == expected_result diff --git a/tests/mlia/test_cli_logging.py b/tests/mlia/test_cli_logging.py new file mode 100644 index 0000000..7c5f299 --- /dev/null +++ b/tests/mlia/test_cli_logging.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module cli.logging.""" +import logging +from pathlib import Path +from typing import Optional + +import pytest + +from mlia.cli.logging import setup_logging +from tests.mlia.utils.logging import clear_loggers + + +def teardown_function() -> None: + """Perform action after test completion. + + This function is launched automatically by pytest after each test + in this module. + """ + clear_loggers() + + +@pytest.mark.parametrize( + "logs_dir, verbose, expected_output, expected_log_file_content", + [ + ( + None, + None, + "cli info\n", + None, + ), + ( + None, + True, + """mlia.tools.aiet_wrapper - aiet debug +cli info +mlia.cli - cli debug +""", + None, + ), + ( + "logs", + True, + """mlia.tools.aiet_wrapper - aiet debug +cli info +mlia.cli - cli debug +""", + """mlia.tools.aiet_wrapper - DEBUG - aiet debug +mlia.cli - DEBUG - cli debug +""", + ), + ], +) +def test_setup_logging( + tmp_path: Path, + capfd: pytest.CaptureFixture, + logs_dir: str, + verbose: bool, + expected_output: str, + expected_log_file_content: str, +) -> None: + """Test function setup_logging.""" + logs_dir_path = tmp_path / logs_dir if logs_dir else None + + setup_logging(logs_dir_path, verbose) + + aiet_logger = logging.getLogger("mlia.tools.aiet_wrapper") + aiet_logger.debug("aiet debug") + + cli_logger = logging.getLogger("mlia.cli") + cli_logger.info("cli info") + cli_logger.debug("cli debug") + + stdout, _ = capfd.readouterr() + assert stdout == expected_output + + check_log_assertions(logs_dir_path, expected_log_file_content) + + +def check_log_assertions( + logs_dir_path: Optional[Path], expected_log_file_content: str +) -> None: + """Test assertions for log file.""" + if logs_dir_path is not None: + assert logs_dir_path.is_dir() + + items = list(logs_dir_path.iterdir()) + assert len(items) == 1 + + log_file_path = items[0] + assert log_file_path.is_file() + + log_file_name = log_file_path.name + assert log_file_name == "mlia.log" + + with open(log_file_path, encoding="utf-8") as log_file: + log_content = log_file.read() + + expected_lines = expected_log_file_content.split("\n") + produced_lines = log_content.split("\n") + + assert len(expected_lines) == len(produced_lines) + for expected, produced in zip(expected_lines, produced_lines): + assert expected in produced diff --git a/tests/mlia/test_cli_main.py b/tests/mlia/test_cli_main.py new file mode 100644 index 0000000..a0937d5 --- /dev/null +++ b/tests/mlia/test_cli_main.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for main module.""" +import argparse +from functools import wraps +from pathlib import Path +from typing import Any +from typing import Callable +from typing import List +from unittest.mock import ANY +from unittest.mock import call +from unittest.mock import MagicMock + +import pytest + +import mlia +from mlia.cli.main import CommandInfo +from mlia.cli.main import main +from mlia.core.context import ExecutionContext +from tests.mlia.utils.logging import clear_loggers + + +def teardown_function() -> None: + """Perform action after test completion. + + This function is launched automatically by pytest after each test + in this module. + """ + clear_loggers() + + +def test_option_version(capfd: pytest.CaptureFixture) -> None: + """Test --version.""" + with pytest.raises(SystemExit) as ex: + main(["--version"]) + + assert ex.type == SystemExit + assert ex.value.code == 0 + + stdout, stderr = capfd.readouterr() + assert len(stdout.splitlines()) == 1 + assert stderr == "" + + +@pytest.mark.parametrize( + "is_default, expected_command_help", + [(True, "Test command [default]"), (False, "Test command")], +) +def test_command_info(is_default: bool, expected_command_help: str) -> None: + """Test properties of CommandInfo object.""" + + def test_command() -> None: + """Test command.""" + + command_info = CommandInfo(test_command, ["test"], [], is_default) + assert command_info.command_name == "test_command" + assert command_info.command_name_and_aliases == ["test_command", "test"] + assert command_info.command_help == expected_command_help + + +def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test adding default command.""" + + def mock_command( + func_mock: MagicMock, name: str, with_working_dir: bool + ) -> Callable[..., None]: + """Mock cli command.""" + + def sample_cmd_1(*args: Any, **kwargs: Any) -> None: + """Sample command.""" + func_mock(*args, **kwargs) + + def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None: + """Another sample command.""" + func_mock(ctx=ctx, **kwargs) + + ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1 + ret_func.__name__ = name + + return ret_func # type: ignore + + default_command = MagicMock() + non_default_command = MagicMock() + + def default_command_params(parser: argparse.ArgumentParser) -> None: + """Add parameters for default command.""" + parser.add_argument("--sample") + parser.add_argument("--default_arg", default="123") + + def non_default_command_params(parser: argparse.ArgumentParser) -> None: + """Add parameters for non default command.""" + parser.add_argument("--param") + + monkeypatch.setattr( + "mlia.cli.main.get_commands", + MagicMock( + return_value=[ + CommandInfo( + func=mock_command(default_command, "default_command", True), + aliases=["command1"], + opt_groups=[default_command_params], + is_default=True, + ), + CommandInfo( + func=mock_command( + non_default_command, "non_default_command", False + ), + aliases=["command2"], + opt_groups=[non_default_command_params], + is_default=False, + ), + ] + ), + ) + + tmp_working_dir = str(tmp_path) + main(["--working-dir", tmp_working_dir, "--sample", "1"]) + main(["command2", "--param", "test"]) + + default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123") + non_default_command.assert_called_once_with(param="test") + + +@pytest.mark.parametrize( + "params, expected_call", + [ + [ + ["operators", "sample_model.tflite"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.tflite", + output=None, + supported_ops_report=False, + ), + ], + [ + ["ops", "sample_model.tflite"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.tflite", + output=None, + supported_ops_report=False, + ), + ], + [ + ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"], + call( + ctx=ANY, + target_profile="ethos-u55-128", + model="sample_model.tflite", + output=None, + supported_ops_report=False, + ), + ], + [ + ["operators"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model=None, + output=None, + supported_ops_report=False, + ), + ], + [ + ["operators", "--supported-ops-report"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model=None, + output=None, + supported_ops_report=True, + ), + ], + [ + [ + "all_tests", + "sample_model.h5", + "--optimization-type", + "pruning", + "--optimization-target", + "0.5", + ], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + optimization_type="pruning", + optimization_target="0.5", + output=None, + evaluate_on=["Vela"], + ), + ], + [ + ["sample_model.h5"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + optimization_type="pruning,clustering", + optimization_target="0.5,32", + output=None, + evaluate_on=["Vela"], + ), + ], + [ + ["performance", "sample_model.h5", "--output", "result.json"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + output="result.json", + evaluate_on=["Vela"], + ), + ], + [ + ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"], + call( + ctx=ANY, + target_profile="ethos-u55-128", + model="sample_model.h5", + output=None, + evaluate_on=["Vela"], + ), + ], + [ + ["optimization", "sample_model.h5"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + optimization_type="pruning,clustering", + optimization_target="0.5,32", + output=None, + evaluate_on=["Vela"], + ), + ], + [ + ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + optimization_type="pruning,clustering", + optimization_target="0.5,32", + output=None, + evaluate_on=["some_backend"], + ), + ], + ], +) +def test_commands_execution( + monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any +) -> None: + """Test calling commands from the main function.""" + mock = MagicMock() + + def wrap_mock_command(command: Callable) -> Callable: + """Wrap the command with the mock.""" + + @wraps(command) + def mock_command(*args: Any, **kwargs: Any) -> Any: + """Mock the command.""" + mock(*args, **kwargs) + + return mock_command + + monkeypatch.setattr( + "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"]) + ) + + monkeypatch.setattr( + "mlia.cli.options.get_available_backends", + MagicMock(return_value=["Vela", "some_backend"]), + ) + + for command in ["all_tests", "operators", "performance", "optimization"]: + monkeypatch.setattr( + f"mlia.cli.main.{command}", + wrap_mock_command(getattr(mlia.cli.main, command)), + ) + + main(params) + + mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs) + + +@pytest.mark.parametrize( + "verbose, exc_mock, expected_output", + [ + [ + True, + MagicMock(side_effect=Exception("Error")), + [ + "Execution finished with error: Error", + f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " + "for more details", + ], + ], + [ + False, + MagicMock(side_effect=Exception("Error")), + [ + "Execution finished with error: Error", + f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " + "for more details, or enable verbose mode", + ], + ], + [ + False, + MagicMock(side_effect=KeyboardInterrupt()), + ["Execution has been interrupted"], + ], + ], +) +def test_verbose_output( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture, + verbose: bool, + exc_mock: MagicMock, + expected_output: List[str], +) -> None: + """Test flag --verbose.""" + + def command_params(parser: argparse.ArgumentParser) -> None: + """Add parameters for non default command.""" + parser.add_argument("--verbose", action="store_true") + + def command() -> None: + """Run test command.""" + exc_mock() + + monkeypatch.setattr( + "mlia.cli.main.get_commands", + MagicMock( + return_value=[ + CommandInfo( + func=command, + aliases=["command"], + opt_groups=[command_params], + ), + ] + ), + ) + + params = ["command"] + if verbose: + params.append("--verbose") + + exit_code = main(params) + assert exit_code == 1 + + stdout, _ = capsys.readouterr() + for expected_message in expected_output: + assert expected_message in stdout diff --git a/tests/mlia/test_cli_options.py b/tests/mlia/test_cli_options.py new file mode 100644 index 0000000..a441e58 --- /dev/null +++ b/tests/mlia/test_cli_options.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module options.""" +import argparse +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import pytest + +from mlia.cli.options import add_output_options +from mlia.cli.options import get_target_profile_opts +from mlia.cli.options import parse_optimization_parameters + + +@pytest.mark.parametrize( + "optimization_type, optimization_target, expected_error, expected_result", + [ + ( + "pruning", + "0.5", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ) + ], + ), + ( + "clustering", + "32", + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ) + ], + ), + ( + "pruning,clustering", + "0.5,32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning, clustering", + "0.5, 32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning,clustering", + "0.5", + pytest.raises( + Exception, match="Wrong number of optimization targets and types" + ), + None, + ), + ( + "", + "0.5", + pytest.raises(Exception, match="Optimization type is not provided"), + None, + ), + ( + "pruning,clustering", + "", + pytest.raises(Exception, match="Optimization target is not provided"), + None, + ), + ( + "pruning,", + "0.5,abc", + pytest.raises( + Exception, match="Non numeric value for the optimization target" + ), + None, + ), + ], +) +def test_parse_optimization_parameters( + optimization_type: str, + optimization_target: str, + expected_error: Any, + expected_result: Any, +) -> None: + """Test function parse_optimization_parameters.""" + with expected_error: + result = parse_optimization_parameters(optimization_type, optimization_target) + assert result == expected_result + + +@pytest.mark.parametrize( + "args, expected_opts", + [ + [ + {}, + [], + ], + [ + {"target_profile": "profile"}, + ["--target-profile", "profile"], + ], + [ + # for the default profile empty list should be returned + {"target": "ethos-u55-256"}, + [], + ], + ], +) +def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None: + """Test getting target options.""" + assert get_target_profile_opts(args) == expected_opts + + +@pytest.mark.parametrize( + "output_parameters, expected_path", + [ + [["--output", "report.json"], "report.json"], + [["--output", "REPORT.JSON"], "REPORT.JSON"], + [["--output", "some_folder/report.json"], "some_folder/report.json"], + [["--output", "report.csv"], "report.csv"], + [["--output", "REPORT.CSV"], "REPORT.CSV"], + [["--output", "some_folder/report.csv"], "some_folder/report.csv"], + ], +) +def test_output_options(output_parameters: List[str], expected_path: str) -> None: + """Test output options resolving.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + args = parser.parse_args(output_parameters) + assert args.output == expected_path + + +@pytest.mark.parametrize( + "output_filename", + [ + "report.txt", + "report.TXT", + "report", + "report.pdf", + ], +) +def test_output_options_bad_parameters( + output_filename: str, capsys: pytest.CaptureFixture +) -> None: + """Test that args parsing should fail if format is not supported.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + with pytest.raises(SystemExit): + parser.parse_args(["--output", output_filename]) + + err_output = capsys.readouterr().err + suffix = Path(output_filename).suffix[1:] + assert f"Unsupported format '{suffix}'" in err_output diff --git a/tests/mlia/test_core_advice_generation.py b/tests/mlia/test_core_advice_generation.py new file mode 100644 index 0000000..05db698 --- /dev/null +++ b/tests/mlia/test_core_advice_generation.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module advice_generation.""" +from typing import List + +import pytest + +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import advice_category +from mlia.core.advice_generation import FactBasedAdviceProducer +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.context import Context + + +def test_advice_generation() -> None: + """Test advice generation.""" + + class SampleProducer(FactBasedAdviceProducer): + """Sample producer.""" + + def produce_advice(self, data_item: DataItem) -> None: + """Process data.""" + self.add_advice([f"Advice for {data_item}"]) + + producer = SampleProducer() + producer.produce_advice(123) + producer.produce_advice("hello") + + advice = producer.get_advice() + assert advice == [Advice(["Advice for 123"]), Advice(["Advice for hello"])] + + +@pytest.mark.parametrize( + "category, expected_advice", + [ + [ + AdviceCategory.OPERATORS, + [Advice(["Good advice!"])], + ], + [ + AdviceCategory.PERFORMANCE, + [], + ], + ], +) +def test_advice_category_decorator( + category: AdviceCategory, + expected_advice: List[Advice], + dummy_context: Context, +) -> None: + """Test for advice_category decorator.""" + + class SampleAdviceProducer(FactBasedAdviceProducer): + """Sample advice producer.""" + + @advice_category(AdviceCategory.OPERATORS) + def produce_advice(self, data_item: DataItem) -> None: + """Produce the advice.""" + self.add_advice(["Good advice!"]) + + producer = SampleAdviceProducer() + dummy_context.update( + advice_category=category, event_handlers=[], config_parameters={} + ) + producer.set_context(dummy_context) + + producer.produce_advice("some_data") + advice = producer.get_advice() + + assert advice == expected_advice diff --git a/tests/mlia/test_core_advisor.py b/tests/mlia/test_core_advisor.py new file mode 100644 index 0000000..375ff62 --- /dev/null +++ b/tests/mlia/test_core_advisor.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module advisor.""" +from unittest.mock import MagicMock + +from mlia.core.advisor import InferenceAdvisor +from mlia.core.context import Context +from mlia.core.workflow import WorkflowExecutor + + +def test_inference_advisor_run() -> None: + """Test running sample inference advisor.""" + executor_mock = MagicMock(spec=WorkflowExecutor) + context_mock = MagicMock(spec=Context) + + class SampleAdvisor(InferenceAdvisor): + """Sample inference advisor.""" + + @classmethod + def name(cls) -> str: + """Return name of the advisor.""" + return "sample_advisor" + + @classmethod + def description(cls) -> str: + """Return description of the advisor.""" + return "Sample advisor" + + @classmethod + def info(cls) -> None: + """Print advisor info.""" + + def configure(self, context: Context) -> WorkflowExecutor: + """Configure advisor.""" + return executor_mock + + advisor = SampleAdvisor() + advisor.run(context_mock) + + executor_mock.run.assert_called_once() diff --git a/tests/mlia/test_core_context.py b/tests/mlia/test_core_context.py new file mode 100644 index 0000000..10015aa --- /dev/null +++ b/tests/mlia/test_core_context.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module context.""" +from pathlib import Path + +from mlia.core.common import AdviceCategory +from mlia.core.context import ExecutionContext +from mlia.core.events import DefaultEventPublisher + + +def test_execution_context(tmpdir: str) -> None: + """Test execution context.""" + publisher = DefaultEventPublisher() + category = AdviceCategory.OPERATORS + + context = ExecutionContext( + advice_category=category, + config_parameters={"param": "value"}, + working_dir=tmpdir, + event_handlers=[], + event_publisher=publisher, + verbose=True, + logs_dir="logs_directory", + models_dir="models_directory", + ) + + assert context.advice_category == category + assert context.config_parameters == {"param": "value"} + assert context.event_handlers == [] + assert context.event_publisher == publisher + assert context.logs_path == Path(tmpdir) / "logs_directory" + expected_model_path = Path(tmpdir) / "models_directory/sample.model" + assert context.get_model_path("sample.model") == expected_model_path + assert context.verbose is True + assert str(context) == ( + f"ExecutionContext: " + f"working_dir={tmpdir}, " + "advice_category=OPERATORS, " + "config_parameters={'param': 'value'}, " + "verbose=True" + ) + + context_with_default_params = ExecutionContext(working_dir=tmpdir) + assert context_with_default_params.advice_category is None + assert context_with_default_params.config_parameters is None + assert context_with_default_params.event_handlers is None + assert isinstance( + context_with_default_params.event_publisher, DefaultEventPublisher + ) + assert context_with_default_params.logs_path == Path(tmpdir) / "logs" + + default_model_path = context_with_default_params.get_model_path("sample.model") + expected_default_model_path = Path(tmpdir) / "models/sample.model" + assert default_model_path == expected_default_model_path + + expected_str = ( + f"ExecutionContext: working_dir={tmpdir}, " + "advice_category=, " + "config_parameters=None, " + "verbose=False" + ) + assert str(context_with_default_params) == expected_str diff --git a/tests/mlia/test_core_data_analysis.py b/tests/mlia/test_core_data_analysis.py new file mode 100644 index 0000000..a782159 --- /dev/null +++ b/tests/mlia/test_core_data_analysis.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module data_analysis.""" +from dataclasses import dataclass + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.core.data_analysis import FactExtractor + + +def test_fact_extractor() -> None: + """Test fact extractor.""" + + @dataclass + class SampleFact(Fact): + """Sample fact.""" + + msg: str + + class SampleExtractor(FactExtractor): + """Sample extractor.""" + + def analyze_data(self, data_item: DataItem) -> None: + self.add_fact(SampleFact(f"Fact for {data_item}")) + + extractor = SampleExtractor() + extractor.analyze_data(42) + extractor.analyze_data("some data") + + facts = extractor.get_analyzed_data() + assert facts == [SampleFact("Fact for 42"), SampleFact("Fact for some data")] diff --git a/tests/mlia/test_core_events.py b/tests/mlia/test_core_events.py new file mode 100644 index 0000000..faaab7c --- /dev/null +++ b/tests/mlia/test_core_events.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module events.""" +from dataclasses import dataclass +from unittest.mock import call +from unittest.mock import MagicMock + +import pytest + +from mlia.core.events import action +from mlia.core.events import ActionFinishedEvent +from mlia.core.events import ActionStartedEvent +from mlia.core.events import DebugEventHandler +from mlia.core.events import DefaultEventPublisher +from mlia.core.events import Event +from mlia.core.events import EventDispatcher +from mlia.core.events import EventHandler +from mlia.core.events import ExecutionFinishedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.events import stage +from mlia.core.events import SystemEventsHandler + + +@dataclass +class SampleEvent(Event): + """Sample event.""" + + msg: str + + +def test_event_publisher() -> None: + """Test event publishing.""" + publisher = DefaultEventPublisher() + handler_mock1 = MagicMock(spec=EventHandler) + handler_mock2 = MagicMock(spec=EventHandler) + + publisher.register_event_handlers([handler_mock1, handler_mock2]) + + event = SampleEvent("hello, event!") + publisher.publish_event(event) + + handler_mock1.handle_event.assert_called_once_with(event) + handler_mock2.handle_event.assert_called_once_with(event) + + +def test_stage_context_manager() -> None: + """Test stage context manager.""" + publisher = DefaultEventPublisher() + + handler_mock = MagicMock(spec=EventHandler) + publisher.register_event_handler(handler_mock) + + events = (SampleEvent("hello"), SampleEvent("goodbye")) + with stage(publisher, events): + print("perform actions") + + assert handler_mock.handle_event.call_count == 2 + calls = [call(event) for event in events] + handler_mock.handle_event.assert_has_calls(calls) + + +def test_action_context_manager() -> None: + """Test action stage context manager.""" + publisher = DefaultEventPublisher() + + handler_mock = MagicMock(spec=EventHandler) + publisher.register_event_handler(handler_mock) + + with action(publisher, "Sample action"): + print("perform actions") + + assert handler_mock.handle_event.call_count == 2 + calls = handler_mock.handle_event.mock_calls + + action_started = calls[0].args[0] + action_finished = calls[1].args[0] + + assert isinstance(action_started, ActionStartedEvent) + assert isinstance(action_finished, ActionFinishedEvent) + + assert action_finished.parent_event_id == action_started.event_id + + +def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None: + """Test debugging event handler.""" + publisher = DefaultEventPublisher() + + publisher.register_event_handler(DebugEventHandler()) + publisher.register_event_handler(DebugEventHandler(with_stacktrace=True)) + + messages = ["Sample event 1", "Sample event 2"] + for message in messages: + publisher.publish_event(SampleEvent(message)) + + captured = capsys.readouterr() + for message in messages: + assert message in captured.out + + assert "traceback.print_stack" in captured.err + + +def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None: + """Test event dispatcher.""" + + class SampleEventHandler(EventDispatcher): + """Sample event handler.""" + + def on_sample_event( # pylint: disable=no-self-use + self, _event: SampleEvent + ) -> None: + """Event handler for SampleEvent.""" + print("Got sample event") + + publisher = DefaultEventPublisher() + publisher.register_event_handler(SampleEventHandler()) + publisher.publish_event(SampleEvent("Sample event")) + + captured = capsys.readouterr() + assert captured.out.strip() == "Got sample event" + + +def test_system_events_handler(capsys: pytest.CaptureFixture) -> None: + """Test system events handler.""" + + class CustomSystemEventHandler(SystemEventsHandler): + """Custom system event handler.""" + + def on_execution_started(self, event: ExecutionStartedEvent) -> None: + """Handle ExecutionStarted event.""" + print("Execution started") + + def on_execution_finished(self, event: ExecutionFinishedEvent) -> None: + """Handle ExecutionFinished event.""" + print("Execution finished") + + publisher = DefaultEventPublisher() + publisher.register_event_handler(CustomSystemEventHandler()) + + publisher.publish_event(ExecutionStartedEvent()) + publisher.publish_event(SampleEvent("Hello world!")) + publisher.publish_event(ExecutionFinishedEvent()) + + captured = capsys.readouterr() + assert captured.out.strip() == "Execution started\nExecution finished" + + +def test_compare_without_id() -> None: + """Test event comparison without event_id.""" + event1 = SampleEvent("message") + event2 = SampleEvent("message") + + assert event1 != event2 + assert event1.compare_without_id(event2) + + assert not event1.compare_without_id("message") # type: ignore diff --git a/tests/mlia/test_core_helpers.py b/tests/mlia/test_core_helpers.py new file mode 100644 index 0000000..8577617 --- /dev/null +++ b/tests/mlia/test_core_helpers.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the helper classes.""" +from mlia.core.helpers import APIActionResolver + + +def test_api_action_resolver() -> None: + """Test APIActionResolver class.""" + helper = APIActionResolver() + + # pylint: disable=use-implicit-booleaness-not-comparison + assert helper.apply_optimizations() == [] + assert helper.supported_operators_info() == [] + assert helper.check_performance() == [] + assert helper.check_operator_compatibility() == [] + assert helper.operator_compatibility_details() == [] + assert helper.optimization_details() == [] diff --git a/tests/mlia/test_core_mixins.py b/tests/mlia/test_core_mixins.py new file mode 100644 index 0000000..d66213d --- /dev/null +++ b/tests/mlia/test_core_mixins.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module mixins.""" +import pytest + +from mlia.core.common import AdviceCategory +from mlia.core.context import Context +from mlia.core.context import ExecutionContext +from mlia.core.mixins import ContextMixin +from mlia.core.mixins import ParameterResolverMixin + + +def test_context_mixin(dummy_context: Context) -> None: + """Test ContextMixin.""" + + class SampleClass(ContextMixin): + """Sample class.""" + + sample_object = SampleClass() + sample_object.set_context(dummy_context) + assert sample_object.context == dummy_context + + +class TestParameterResolverMixin: + """Tests for parameter resolver mixin.""" + + @staticmethod + def test_parameter_resolver_mixin(dummy_context: ExecutionContext) -> None: + """Test ParameterResolverMixin.""" + + class SampleClass(ParameterResolverMixin): + """Sample class.""" + + def __init__(self) -> None: + """Init sample object.""" + self.context = dummy_context + + self.context.update( + advice_category=AdviceCategory.OPERATORS, + event_handlers=[], + config_parameters={"section": {"param": 123}}, + ) + + sample_object = SampleClass() + value = sample_object.get_parameter("section", "param") + assert value == 123 + + with pytest.raises( + Exception, match="Parameter param expected to have type " + ): + value = sample_object.get_parameter("section", "param", expected_type=str) + + with pytest.raises(Exception, match="Parameter no_param is not set"): + value = sample_object.get_parameter("section", "no_param") + + @staticmethod + def test_parameter_resolver_mixin_no_config( + dummy_context: ExecutionContext, + ) -> None: + """Test ParameterResolverMixin without config params.""" + + class SampleClassNoConfig(ParameterResolverMixin): + """Sample context without config params.""" + + def __init__(self) -> None: + """Init sample object.""" + self.context = dummy_context + + with pytest.raises(Exception, match="Configuration parameters are not set"): + sample_object_no_config = SampleClassNoConfig() + sample_object_no_config.get_parameter("section", "param", expected_type=str) + + @staticmethod + def test_parameter_resolver_mixin_bad_section( + dummy_context: ExecutionContext, + ) -> None: + """Test ParameterResolverMixin without config params.""" + + class SampleClassBadSection(ParameterResolverMixin): + """Sample context with bad section in config.""" + + def __init__(self) -> None: + """Init sample object.""" + self.context = dummy_context + self.context.update( + advice_category=AdviceCategory.OPERATORS, + event_handlers=[], + config_parameters={"section": ["param"]}, + ) + + with pytest.raises( + Exception, + match="Parameter section section has wrong format, " + "expected to be a dictionary", + ): + sample_object_bad_section = SampleClassBadSection() + sample_object_bad_section.get_parameter( + "section", "param", expected_type=str + ) diff --git a/tests/mlia/test_core_performance.py b/tests/mlia/test_core_performance.py new file mode 100644 index 0000000..0d28fe8 --- /dev/null +++ b/tests/mlia/test_core_performance.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module performance.""" +from pathlib import Path + +from mlia.core.performance import estimate_performance +from mlia.core.performance import PerformanceEstimator + + +def test_estimate_performance(tmp_path: Path) -> None: + """Test function estimate_performance.""" + model_path = tmp_path / "original.tflite" + + class SampleEstimator(PerformanceEstimator[Path, int]): + """Sample estimator.""" + + def estimate(self, model: Path) -> int: + """Estimate performance.""" + if model.name == "original.tflite": + return 1 + + return 2 + + def optimized_model(_original: Path) -> Path: + """Return path to the 'optimized' model.""" + return tmp_path / "optimized.tflite" + + results = estimate_performance(model_path, SampleEstimator(), [optimized_model]) + assert results == [1, 2] diff --git a/tests/mlia/test_core_reporting.py b/tests/mlia/test_core_reporting.py new file mode 100644 index 0000000..2f7ec22 --- /dev/null +++ b/tests/mlia/test_core_reporting.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for reporting module.""" +from typing import List + +import pytest + +from mlia.core.reporting import BytesCell +from mlia.core.reporting import Cell +from mlia.core.reporting import ClockCell +from mlia.core.reporting import Column +from mlia.core.reporting import CyclesCell +from mlia.core.reporting import Format +from mlia.core.reporting import NestedReport +from mlia.core.reporting import ReportItem +from mlia.core.reporting import SingleRow +from mlia.core.reporting import Table +from mlia.utils.console import remove_ascii_codes + + +@pytest.mark.parametrize( + "cell, expected_repr", + [ + (BytesCell(None), ""), + (BytesCell(0), "0 bytes"), + (BytesCell(1), "1 byte"), + (BytesCell(100000), "100,000 bytes"), + (ClockCell(None), ""), + (ClockCell(0), "0 Hz"), + (ClockCell(1), "1 Hz"), + (ClockCell(100000), "100,000 Hz"), + (CyclesCell(None), ""), + (CyclesCell(0), "0 cycles"), + (CyclesCell(1), "1 cycle"), + (CyclesCell(100000), "100,000 cycles"), + ], +) +def test_predefined_cell_types(cell: Cell, expected_repr: str) -> None: + """Test predefined cell types.""" + assert str(cell) == expected_repr + + +@pytest.mark.parametrize( + "with_notes, expected_text_report", + [ + [ + True, + """ +Sample table: +┌──────────┬──────────┬──────────┐ +│ Header 1 │ Header 2 │ Header 3 │ +╞══════════╪══════════╪══════════╡ +│ 1 │ 2 │ 3 │ +├──────────┼──────────┼──────────┤ +│ 4 │ 5 │ 123,123 │ +└──────────┴──────────┴──────────┘ +Sample notes + """.strip(), + ], + [ + False, + """ +Sample table: +┌──────────┬──────────┬──────────┐ +│ Header 1 │ Header 2 │ Header 3 │ +╞══════════╪══════════╪══════════╡ +│ 1 │ 2 │ 3 │ +├──────────┼──────────┼──────────┤ +│ 4 │ 5 │ 123,123 │ +└──────────┴──────────┴──────────┘ + """.strip(), + ], + ], +) +def test_table_representation(with_notes: bool, expected_text_report: str) -> None: + """Test table report representation.""" + + def sample_table(with_notes: bool) -> Table: + columns = [ + Column("Header 1", alias="header1", only_for=["plain_text"]), + Column("Header 2", alias="header2", fmt=Format(wrap_width=5)), + Column("Header 3", alias="header3"), + ] + rows = [(1, 2, 3), (4, 5, Cell(123123, fmt=Format(str_fmt="0,d")))] + + return Table( + columns, + rows, + name="Sample table", + alias="sample_table", + notes="Sample notes" if with_notes else None, + ) + + table = sample_table(with_notes) + csv_repr = table.to_csv() + assert csv_repr == [["Header 2", "Header 3"], [2, 3], [5, 123123]] + + json_repr = table.to_json() + assert json_repr == { + "sample_table": [ + {"header2": 2, "header3": 3}, + {"header2": 5, "header3": 123123}, + ] + } + + text_report = remove_ascii_codes(table.to_plain_text()) + assert text_report == expected_text_report + + +def test_csv_nested_table_representation() -> None: + """Test representation of the nested tables in csv format.""" + + def sample_table(num_of_cols: int) -> Table: + columns = [ + Column("Header 1", alias="header1"), + Column("Header 2", alias="header2"), + ] + + rows = [ + ( + 1, + Table( + columns=[ + Column(f"Nested column {i+1}") for i in range(num_of_cols) + ], + rows=[[f"value{i+1}" for i in range(num_of_cols)]], + name="Nested table", + ), + ) + ] + + return Table(columns, rows, name="Sample table", alias="sample_table") + + assert sample_table(num_of_cols=2).to_csv() == [ + ["Header 1", "Header 2"], + [1, "value1;value2"], + ] + + assert sample_table(num_of_cols=1).to_csv() == [ + ["Header 1", "Header 2"], + [1, "value1"], + ] + + assert sample_table(num_of_cols=0).to_csv() == [ + ["Header 1", "Header 2"], + [1, ""], + ] + + +@pytest.mark.parametrize( + "report, expected_plain_text, expected_json_data, expected_csv_data", + [ + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem("Item", "item", "item_value"), + ], + ), + """ +Sample report: + Item item_value +""".strip(), + { + "sample_report": {"item": "item_value"}, + }, + [ + ("item",), + ("item_value",), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ReportItem("Nested item", "nested_item", "nested_item_value")], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item nested_item_value +""".strip(), + { + "sample_report": { + "item": {"nested_item": "nested_item_value"}, + }, + }, + [ + ("item", "nested_item"), + ("item_value", "nested_item_value"), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ReportItem("Nested item", "nested_item", BytesCell(10))], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item 10 bytes +""".strip(), + { + "sample_report": { + "item": {"nested_item": {"unit": "bytes", "value": 10}}, + }, + }, + [ + ("item", "nested_item_value", "nested_item_unit"), + ("item_value", 10, "bytes"), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ + ReportItem( + "Nested item", + "nested_item", + Cell( + 10, fmt=Format(str_fmt=lambda x: f"{x} cell value") + ), + ) + ], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item 10 cell value +""".strip(), + { + "sample_report": { + "item": {"nested_item": 10}, + }, + }, + [ + ("item", "nested_item"), + ("item_value", 10), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ + ReportItem( + "Nested item", + "nested_item", + Cell( + 10, fmt=Format(str_fmt=lambda x: f"{x} cell value") + ), + ) + ], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item 10 cell value +""".strip(), + { + "sample_report": { + "item": {"nested_item": 10}, + }, + }, + [ + ("item", "nested_item"), + ("item_value", 10), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ + ReportItem("Nested item", "nested_item", Cell(10)), + ], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item 10 +""".strip(), + { + "sample_report": { + "item": {"nested_item": 10}, + }, + }, + [ + ("item", "nested_item"), + ("item_value", 10), + ], + ), + ( + NestedReport( + "Sample report", + "sample_report", + [ + ReportItem( + "Item", + "item", + "item_value", + [ + ReportItem( + "Nested item", "nested_item", Cell(10, fmt=Format()) + ), + ], + ), + ], + ), + """ +Sample report: + Item item_value + Nested item 10 +""".strip(), + { + "sample_report": { + "item": {"nested_item": 10}, + }, + }, + [ + ("item", "nested_item"), + ("item_value", 10), + ], + ), + ], +) +def test_nested_report_representation( + report: NestedReport, + expected_plain_text: str, + expected_json_data: dict, + expected_csv_data: List, +) -> None: + """Test representation of the NestedReport.""" + plain_text = report.to_plain_text() + assert plain_text == expected_plain_text + + json_data = report.to_json() + assert json_data == expected_json_data + + csv_data = report.to_csv() + assert csv_data == expected_csv_data + + +def test_single_row_representation() -> None: + """Test representation of the SingleRow.""" + single_row = SingleRow( + columns=[ + Column("column1", "column1"), + ], + rows=[("value1", "value2")], + name="Single row example", + alias="simple_row_example", + ) + + expected_text = """ +Single row example: + column1 value1 +""".strip() + assert single_row.to_plain_text() == expected_text + assert single_row.to_csv() == [["column1"], ["value1"]] + assert single_row.to_json() == {"simple_row_example": [{"column1": "value1"}]} + + with pytest.raises(Exception, match="Table should have only one row"): + wrong_single_row = SingleRow( + columns=[ + Column("column1", "column1"), + ], + rows=[ + ("value1", "value2"), + ("value1", "value2"), + ], + name="Single row example", + alias="simple_row_example", + ) + wrong_single_row.to_plain_text() diff --git a/tests/mlia/test_core_workflow.py b/tests/mlia/test_core_workflow.py new file mode 100644 index 0000000..470e572 --- /dev/null +++ b/tests/mlia/test_core_workflow.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module workflow.""" +from dataclasses import dataclass +from unittest.mock import call +from unittest.mock import MagicMock + +from mlia.core.advice_generation import Advice +from mlia.core.advice_generation import AdviceEvent +from mlia.core.advice_generation import ContextAwareAdviceProducer +from mlia.core.context import ExecutionContext +from mlia.core.data_analysis import ContextAwareDataAnalyzer +from mlia.core.data_collection import ContextAwareDataCollector +from mlia.core.errors import FunctionalityNotSupportedError +from mlia.core.events import AdviceStageFinishedEvent +from mlia.core.events import AdviceStageStartedEvent +from mlia.core.events import AnalyzedDataEvent +from mlia.core.events import CollectedDataEvent +from mlia.core.events import DataAnalysisStageFinishedEvent +from mlia.core.events import DataAnalysisStageStartedEvent +from mlia.core.events import DataCollectionStageFinishedEvent +from mlia.core.events import DataCollectionStageStartedEvent +from mlia.core.events import DataCollectorSkippedEvent +from mlia.core.events import DefaultEventPublisher +from mlia.core.events import Event +from mlia.core.events import EventHandler +from mlia.core.events import ExecutionFailedEvent +from mlia.core.events import ExecutionFinishedEvent +from mlia.core.events import ExecutionStartedEvent +from mlia.core.workflow import DefaultWorkflowExecutor + + +@dataclass +class SampleEvent(Event): + """Sample event.""" + + msg: str + + +def test_workflow_executor(tmpdir: str) -> None: + """Test workflow executor.""" + handler_mock = MagicMock(spec=EventHandler) + data_collector_mock = MagicMock(spec=ContextAwareDataCollector) + data_collector_mock.collect_data.return_value = 42 + + data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector) + data_collector_mock_no_value.collect_data.return_value = None + + data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector) + data_collector_mock_skipped.name.return_value = "skipped_collector" + data_collector_mock_skipped.collect_data.side_effect = ( + FunctionalityNotSupportedError("Error!", "Error!") + ) + + data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer) + data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"] + + advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer) + advice_producer_mock1.get_advice.return_value = Advice(["All good!"]) + + advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer) + advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])] + + context = ExecutionContext( + working_dir=tmpdir, + event_handlers=[handler_mock], + event_publisher=DefaultEventPublisher(), + ) + + executor = DefaultWorkflowExecutor( + context, + [ + data_collector_mock, + data_collector_mock_no_value, + data_collector_mock_skipped, + ], + [data_analyzer_mock], + [ + advice_producer_mock1, + advice_producer_mock2, + ], + [SampleEvent("Hello from advisor!")], + ) + + executor.run() + + data_collector_mock.collect_data.assert_called_once() + data_collector_mock_no_value.collect_data.assert_called_once() + data_collector_mock_skipped.collect_data.assert_called_once() + + data_analyzer_mock.analyze_data.assert_called_once_with(42) + + advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!") + advice_producer_mock1.get_advice.assert_called_once() + + advice_producer_mock2.produce_advice.called_once_with("Really good number!") + advice_producer_mock2.get_advice.assert_called_once() + + expected_mock_calls = [ + call(ExecutionStartedEvent()), + call(SampleEvent("Hello from advisor!")), + call(DataCollectionStageStartedEvent()), + call(CollectedDataEvent(data_item=42)), + call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")), + call(DataCollectionStageFinishedEvent()), + call(DataAnalysisStageStartedEvent()), + call(AnalyzedDataEvent(data_item="Really good number!")), + call(DataAnalysisStageFinishedEvent()), + call(AdviceStageStartedEvent()), + call(AdviceEvent(advice=Advice(messages=["All good!"]))), + call(AdviceEvent(advice=Advice(messages=["Good advice!"]))), + call(AdviceStageFinishedEvent()), + call(ExecutionFinishedEvent()), + ] + + for expected_call, actual_call in zip( + expected_mock_calls, handler_mock.handle_event.mock_calls + ): + expected_event = expected_call.args[0] + actual_event = actual_call.args[0] + + assert actual_event.compare_without_id(expected_event) + + +def test_workflow_executor_failed(tmpdir: str) -> None: + """Test scenario when one of the components raises exception.""" + handler_mock = MagicMock(spec=EventHandler) + + context = ExecutionContext( + working_dir=tmpdir, + event_handlers=[handler_mock], + event_publisher=DefaultEventPublisher(), + ) + + collection_exception = Exception("Collection failed") + + data_collector_mock = MagicMock(spec=ContextAwareDataCollector) + data_collector_mock.collect_data.side_effect = collection_exception + + executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], []) + executor.run() + + expected_mock_calls = [ + call(ExecutionStartedEvent()), + call(DataCollectionStageStartedEvent()), + call(ExecutionFailedEvent(collection_exception)), + ] + + for expected_call, actual_call in zip( + expected_mock_calls, handler_mock.handle_event.mock_calls + ): + expected_event = expected_call.args[0] + actual_event = actual_call.args[0] + + if isinstance(actual_event, ExecutionFailedEvent): + # seems that dataclass comparison doesn't work well + # for the exceptions + actual_exception = actual_event.err + expected_exception = expected_event.err + + assert actual_exception == expected_exception + continue + + assert actual_event.compare_without_id(expected_event) diff --git a/tests/mlia/test_devices_ethosu_advice_generation.py b/tests/mlia/test_devices_ethosu_advice_generation.py new file mode 100644 index 0000000..98c8a57 --- /dev/null +++ b/tests/mlia/test_devices_ethosu_advice_generation.py @@ -0,0 +1,483 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Ethos-U advice generation.""" +from typing import List +from typing import Optional + +import pytest + +from mlia.cli.helpers import CLIActionResolver +from mlia.core.advice_generation import Advice +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.context import ExecutionContext +from mlia.core.helpers import ActionResolver +from mlia.core.helpers import APIActionResolver +from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer +from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer +from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU +from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators +from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators +from mlia.devices.ethosu.data_analysis import OptimizationDiff +from mlia.devices.ethosu.data_analysis import OptimizationResults +from mlia.devices.ethosu.data_analysis import PerfMetricDiff +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings + + +@pytest.mark.parametrize( + "input_data, advice_category, action_resolver, expected_advice", + [ + [ + AllOperatorsSupportedOnNPU(), + AdviceCategory.OPERATORS, + APIActionResolver(), + [ + Advice( + [ + "You don't have any unsupported operators, your model will " + "run completely on NPU." + ] + ) + ], + ], + [ + AllOperatorsSupportedOnNPU(), + AdviceCategory.OPERATORS, + CLIActionResolver( + { + "target_profile": "sample_target", + "model": "sample_model.tflite", + } + ), + [ + Advice( + [ + "You don't have any unsupported operators, your model will " + "run completely on NPU.", + "Check the estimated performance by running the " + "following command: ", + "mlia performance --target-profile sample_target " + "sample_model.tflite", + ] + ) + ], + ], + [ + HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), + AdviceCategory.OPERATORS, + APIActionResolver(), + [ + Advice( + [ + "You have at least 3 operators that is CPU only: " + "OP1,OP2,OP3.", + "Using operators that are supported by the NPU will " + "improve performance.", + ] + ) + ], + ], + [ + HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), + AdviceCategory.OPERATORS, + CLIActionResolver({}), + [ + Advice( + [ + "You have at least 3 operators that is CPU only: " + "OP1,OP2,OP3.", + "Using operators that are supported by the NPU will " + "improve performance.", + "For guidance on supported operators, run: mlia operators " + "--supported-ops-report", + ] + ) + ], + ], + [ + HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), + AdviceCategory.OPERATORS, + APIActionResolver(), + [ + Advice( + [ + "You have 40% of operators that cannot be placed on the NPU.", + "For better performance, please review the reasons reported " + "in the table, and adjust the model accordingly " + "where possible.", + ] + ) + ], + ], + [ + HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), + AdviceCategory.OPERATORS, + CLIActionResolver({}), + [ + Advice( + [ + "You have 40% of operators that cannot be placed on the NPU.", + "For better performance, please review the reasons reported " + "in the table, and adjust the model accordingly " + "where possible.", + ] + ) + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[OptimizationSettings("pruning", 0.5, None)], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 50), + "on_chip_flash": PerfMetricDiff(100, 100), + "off_chip_flash": PerfMetricDiff(100, 100), + "npu_total_cycles": PerfMetricDiff(10, 5), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [ + Advice( + [ + "With the selected optimization (pruning: 0.5)", + "- You have achieved 50.00% performance improvement in " + "DRAM used (KB)", + "- You have achieved 50.00% performance improvement in " + "NPU total cycles", + "- SRAM used (KB) have degraded by 50.00%", + "You can try to push the optimization target higher " + "(e.g. pruning: 0.6) " + "to check if those results can be further improved.", + ] + ), + Advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ), + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[OptimizationSettings("pruning", 0.5, None)], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 50), + "on_chip_flash": PerfMetricDiff(100, 100), + "off_chip_flash": PerfMetricDiff(100, 100), + "npu_total_cycles": PerfMetricDiff(10, 5), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + CLIActionResolver({"model": "sample_model.h5"}), + [ + Advice( + [ + "With the selected optimization (pruning: 0.5)", + "- You have achieved 50.00% performance improvement in " + "DRAM used (KB)", + "- You have achieved 50.00% performance improvement in " + "NPU total cycles", + "- SRAM used (KB) have degraded by 50.00%", + "You can try to push the optimization target higher " + "(e.g. pruning: 0.6) " + "to check if those results can be further improved.", + "For more info: mlia optimization --help", + "Optimization command: " + "mlia optimization --optimization-type pruning " + "--optimization-target 0.6 sample_model.h5", + ] + ), + Advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ), + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[ + OptimizationSettings("pruning", 0.5, None), + OptimizationSettings("clustering", 32, None), + ], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 50), + "on_chip_flash": PerfMetricDiff(100, 100), + "off_chip_flash": PerfMetricDiff(100, 100), + "npu_total_cycles": PerfMetricDiff(10, 5), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [ + Advice( + [ + "With the selected optimization (pruning: 0.5, clustering: 32)", + "- You have achieved 50.00% performance improvement in " + "DRAM used (KB)", + "- You have achieved 50.00% performance improvement in " + "NPU total cycles", + "- SRAM used (KB) have degraded by 50.00%", + "You can try to push the optimization target higher " + "(e.g. pruning: 0.6 and/or clustering: 16) " + "to check if those results can be further improved.", + ] + ), + Advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ), + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[ + OptimizationSettings("clustering", 2, None), + ], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 50), + "on_chip_flash": PerfMetricDiff(100, 100), + "off_chip_flash": PerfMetricDiff(100, 100), + "npu_total_cycles": PerfMetricDiff(10, 5), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [ + Advice( + [ + "With the selected optimization (clustering: 2)", + "- You have achieved 50.00% performance improvement in " + "DRAM used (KB)", + "- You have achieved 50.00% performance improvement in " + "NPU total cycles", + "- SRAM used (KB) have degraded by 50.00%", + ] + ), + Advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ), + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[OptimizationSettings("pruning", 0.5, None)], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 150), + "on_chip_flash": PerfMetricDiff(100, 150), + "off_chip_flash": PerfMetricDiff(100, 150), + "npu_total_cycles": PerfMetricDiff(10, 100), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [ + Advice( + [ + "With the selected optimization (pruning: 0.5)", + "- DRAM used (KB) have degraded by 50.00%", + "- SRAM used (KB) have degraded by 50.00%", + "- On chip flash used (KB) have degraded by 50.00%", + "- Off chip flash used (KB) have degraded by 50.00%", + "- NPU total cycles have degraded by 900.00%", + "The performance seems to have degraded after " + "applying the selected optimizations, " + "try exploring different optimization types/targets.", + ] + ), + Advice( + [ + "The applied tooling techniques have an impact " + "on accuracy. Additional hyperparameter tuning may be required " + "after any optimization." + ] + ), + ], + ], + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[OptimizationSettings("pruning", 0.5, None)], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 150), + "on_chip_flash": PerfMetricDiff(100, 150), + "off_chip_flash": PerfMetricDiff(100, 150), + "npu_total_cycles": PerfMetricDiff(10, 100), + }, + ), + OptimizationDiff( + opt_type=[OptimizationSettings("pruning", 0.6, None)], + opt_diffs={ + "sram": PerfMetricDiff(100, 150), + "dram": PerfMetricDiff(100, 150), + "on_chip_flash": PerfMetricDiff(100, 150), + "off_chip_flash": PerfMetricDiff(100, 150), + "npu_total_cycles": PerfMetricDiff(10, 100), + }, + ), + ] + ), + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [], # no advice for more than one optimization result + ], + ], +) +def test_ethosu_advice_producer( + tmpdir: str, + input_data: DataItem, + expected_advice: List[Advice], + advice_category: AdviceCategory, + action_resolver: ActionResolver, +) -> None: + """Test Ethos-U Advice producer.""" + producer = EthosUAdviceProducer() + + context = ExecutionContext( + advice_category=advice_category, + working_dir=tmpdir, + action_resolver=action_resolver, + ) + + producer.set_context(context) + producer.produce_advice(input_data) + + assert producer.get_advice() == expected_advice + + +@pytest.mark.parametrize( + "advice_category, action_resolver, expected_advice", + [ + [ + None, + None, + [], + ], + [ + AdviceCategory.OPERATORS, + None, + [], + ], + [ + AdviceCategory.PERFORMANCE, + APIActionResolver(), + [ + Advice( + [ + "You can improve the inference time by using only operators " + "that are supported by the NPU.", + ] + ), + Advice( + [ + "Check if you can improve the performance by applying " + "tooling techniques to your model." + ] + ), + ], + ], + [ + AdviceCategory.PERFORMANCE, + CLIActionResolver({"model": "test_model.h5"}), + [ + Advice( + [ + "You can improve the inference time by using only operators " + "that are supported by the NPU.", + "Try running the following command to verify that:", + "mlia operators test_model.h5", + ] + ), + Advice( + [ + "Check if you can improve the performance by applying " + "tooling techniques to your model.", + "For example: mlia optimization --optimization-type " + "pruning,clustering --optimization-target 0.5,32 " + "test_model.h5", + "For more info: mlia optimization --help", + ] + ), + ], + ], + [ + AdviceCategory.OPTIMIZATION, + APIActionResolver(), + [ + Advice( + [ + "For better performance, make sure that all the operators " + "of your final TFLite model are supported by the NPU.", + ] + ) + ], + ], + [ + AdviceCategory.OPTIMIZATION, + CLIActionResolver({"model": "test_model.h5"}), + [ + Advice( + [ + "For better performance, make sure that all the operators " + "of your final TFLite model are supported by the NPU.", + "For more details, run: mlia operators --help", + ] + ) + ], + ], + ], +) +def test_ethosu_static_advice_producer( + tmpdir: str, + advice_category: Optional[AdviceCategory], + action_resolver: ActionResolver, + expected_advice: List[Advice], +) -> None: + """Test static advice generation.""" + producer = EthosUStaticAdviceProducer() + + context = ExecutionContext( + advice_category=advice_category, + working_dir=tmpdir, + action_resolver=action_resolver, + ) + producer.set_context(context) + assert producer.get_advice() == expected_advice diff --git a/tests/mlia/test_devices_ethosu_advisor.py b/tests/mlia/test_devices_ethosu_advisor.py new file mode 100644 index 0000000..74d2408 --- /dev/null +++ b/tests/mlia/test_devices_ethosu_advisor.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Ethos-U MLIA module.""" +from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor + + +def test_advisor_metadata() -> None: + """Test advisor metadata.""" + assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor" diff --git a/tests/mlia/test_devices_ethosu_config.py b/tests/mlia/test_devices_ethosu_config.py new file mode 100644 index 0000000..49c999a --- /dev/null +++ b/tests/mlia/test_devices_ethosu_config.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for config module.""" +from contextlib import ExitStack as does_not_raise +from typing import Any +from typing import Dict +from unittest.mock import MagicMock + +import pytest + +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.config import get_target +from mlia.tools.vela_wrapper import VelaCompilerOptions +from mlia.utils.filesystem import get_vela_config + + +def test_compiler_options_default_init() -> None: + """Test compiler options default init.""" + opts = VelaCompilerOptions() + + assert opts.config_files is None + assert opts.system_config == "internal-default" + assert opts.memory_mode == "internal-default" + assert opts.accelerator_config is None + assert opts.max_block_dependency == 3 + assert opts.arena_cache_size is None + assert opts.tensor_allocator == "HillClimb" + assert opts.cpu_tensor_alignment == 16 + assert opts.optimization_strategy == "Performance" + assert opts.output_dir is None + + +def test_ethosu_target() -> None: + """Test Ethos-U target configuration init.""" + default_config = EthosUConfiguration("ethos-u55-256") + + assert default_config.target == "ethos-u55" + assert default_config.mac == 256 + assert default_config.compiler_options is not None + + +def test_get_target() -> None: + """Test function get_target.""" + with pytest.raises(Exception, match="No target profile given"): + get_target(None) # type: ignore + + with pytest.raises(Exception, match="Unable to find target profile unknown"): + get_target("unknown") + + u65_device = get_target("ethos-u65-512") + + assert isinstance(u65_device, EthosUConfiguration) + assert u65_device.target == "ethos-u65" + assert u65_device.mac == 512 + assert u65_device.compiler_options.accelerator_config == "ethos-u65-512" + assert u65_device.compiler_options.memory_mode == "Dedicated_Sram" + assert u65_device.compiler_options.config_files == str(get_vela_config()) + + +@pytest.mark.parametrize( + "profile_data, expected_error", + [ + [ + {}, + pytest.raises( + Exception, + match="Mandatory fields missing from target profile: " + r"\['mac', 'memory_mode', 'system_config', 'target'\]", + ), + ], + [ + {"target": "ethos-u65", "mac": 512}, + pytest.raises( + Exception, + match="Mandatory fields missing from target profile: " + r"\['memory_mode', 'system_config'\]", + ), + ], + [ + { + "target": "ethos-u65", + "mac": 2, + "system_config": "Ethos_U65_Embedded", + "memory_mode": "Shared_Sram", + }, + pytest.raises( + Exception, + match=r"Mac value for selected device should be in \[256, 512\]", + ), + ], + [ + { + "target": "ethos-u55", + "mac": 1, + "system_config": "Ethos_U55_High_End_Embedded", + "memory_mode": "Shared_Sram", + }, + pytest.raises( + Exception, + match="Mac value for selected device should be " + r"in \[32, 64, 128, 256\]", + ), + ], + [ + { + "target": "ethos-u65", + "mac": 512, + "system_config": "Ethos_U65_Embedded", + "memory_mode": "Shared_Sram", + }, + does_not_raise(), + ], + ], +) +def test_ethosu_configuration( + monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any +) -> None: + """Test creating Ethos-U configuration.""" + monkeypatch.setattr( + "mlia.devices.ethosu.config.get_profile", MagicMock(return_value=profile_data) + ) + + with expected_error: + EthosUConfiguration("target") diff --git a/tests/mlia/test_devices_ethosu_data_analysis.py b/tests/mlia/test_devices_ethosu_data_analysis.py new file mode 100644 index 0000000..4b1d38b --- /dev/null +++ b/tests/mlia/test_devices_ethosu_data_analysis.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Ethos-U data analysis module.""" +from typing import List + +import pytest + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU +from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer +from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators +from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators +from mlia.devices.ethosu.data_analysis import OptimizationDiff +from mlia.devices.ethosu.data_analysis import OptimizationResults +from mlia.devices.ethosu.data_analysis import PerfMetricDiff +from mlia.devices.ethosu.performance import MemoryUsage +from mlia.devices.ethosu.performance import NPUCycles +from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.tools.vela_wrapper import NpuSupported +from mlia.tools.vela_wrapper import Operator +from mlia.tools.vela_wrapper import Operators + + +def test_perf_metrics_diff() -> None: + """Test PerfMetricsDiff class.""" + diff_same = PerfMetricDiff(1, 1) + assert diff_same.same is True + assert diff_same.improved is False + assert diff_same.degraded is False + assert diff_same.diff == 0 + + diff_improved = PerfMetricDiff(10, 5) + assert diff_improved.same is False + assert diff_improved.improved is True + assert diff_improved.degraded is False + assert diff_improved.diff == 50.0 + + diff_degraded = PerfMetricDiff(5, 10) + assert diff_degraded.same is False + assert diff_degraded.improved is False + assert diff_degraded.degraded is True + assert diff_degraded.diff == -100.0 + + diff_original_zero = PerfMetricDiff(0, 1) + assert diff_original_zero.diff == 0 + + +@pytest.mark.parametrize( + "input_data, expected_facts", + [ + [ + Operators( + [ + Operator( + "CPU operator", + "CPU operator type", + NpuSupported(False, [("CPU only operator", "")]), + ) + ] + ), + [ + HasCPUOnlyOperators(["CPU operator type"]), + HasUnsupportedOnNPUOperators(1.0), + ], + ], + [ + Operators( + [ + Operator( + "NPU operator", + "NPU operator type", + NpuSupported(True, []), + ) + ] + ), + [ + AllOperatorsSupportedOnNPU(), + ], + ], + [ + OptimizationPerformanceMetrics( + PerformanceMetrics( + EthosUConfiguration("ethos-u55-256"), + NPUCycles(1, 2, 3, 4, 5, 6), + # memory metrics are in kilobytes + MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore + ), + [ + [ + [ + OptimizationSettings("pruning", 0.5, None), + ], + PerformanceMetrics( + EthosUConfiguration("ethos-u55-256"), + NPUCycles(1, 2, 3, 4, 5, 6), + # memory metrics are in kilobytes + MemoryUsage( + *[i * 1024 for i in range(1, 6)] # type: ignore + ), + ), + ], + ], + ), + [ + OptimizationResults( + [ + OptimizationDiff( + opt_type=[ + OptimizationSettings("pruning", 0.5, None), + ], + opt_diffs={ + "sram": PerfMetricDiff(1.0, 1.0), + "dram": PerfMetricDiff(2.0, 2.0), + "on_chip_flash": PerfMetricDiff(4.0, 4.0), + "off_chip_flash": PerfMetricDiff(5.0, 5.0), + "npu_total_cycles": PerfMetricDiff(3, 3), + }, + ) + ] + ) + ], + ], + [ + OptimizationPerformanceMetrics( + PerformanceMetrics( + EthosUConfiguration("ethos-u55-256"), + NPUCycles(1, 2, 3, 4, 5, 6), + # memory metrics are in kilobytes + MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore + ), + [], + ), + [], + ], + ], +) +def test_ethos_u_data_analyzer( + input_data: DataItem, expected_facts: List[Fact] +) -> None: + """Test Ethos-U data analyzer.""" + analyzer = EthosUDataAnalyzer() + analyzer.analyze_data(input_data) + assert analyzer.get_analyzed_data() == expected_facts diff --git a/tests/mlia/test_devices_ethosu_data_collection.py b/tests/mlia/test_devices_ethosu_data_collection.py new file mode 100644 index 0000000..897cf41 --- /dev/null +++ b/tests/mlia/test_devices_ethosu_data_collection.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the data collection module for Ethos-U.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.context import Context +from mlia.core.data_collection import DataCollector +from mlia.core.errors import FunctionalityNotSupportedError +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility +from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance +from mlia.devices.ethosu.data_collection import EthosUPerformance +from mlia.devices.ethosu.performance import MemoryUsage +from mlia.devices.ethosu.performance import NPUCycles +from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.tools.vela_wrapper import Operators + + +@pytest.mark.parametrize( + "collector, expected_name", + [ + ( + EthosUOperatorCompatibility, + "ethos_u_operator_compatibility", + ), + ( + EthosUPerformance, + "ethos_u_performance", + ), + ( + EthosUOptimizationPerformance, + "ethos_u_model_optimizations", + ), + ], +) +def test_collectors_metadata( + collector: DataCollector, + expected_name: str, +) -> None: + """Test collectors metadata.""" + assert collector.name() == expected_name + + +def test_operator_compatibility_collector( + dummy_context: Context, test_tflite_model: Path +) -> None: + """Test operator compatibility data collector.""" + device = EthosUConfiguration("ethos-u55-256") + + collector = EthosUOperatorCompatibility(test_tflite_model, device) + collector.set_context(dummy_context) + + result = collector.collect_data() + assert isinstance(result, Operators) + + +def test_performance_collector( + monkeypatch: pytest.MonkeyPatch, dummy_context: Context, test_tflite_model: Path +) -> None: + """Test performance data collector.""" + device = EthosUConfiguration("ethos-u55-256") + + mock_performance_estimation(monkeypatch, device) + + collector = EthosUPerformance(test_tflite_model, device) + collector.set_context(dummy_context) + + result = collector.collect_data() + assert isinstance(result, PerformanceMetrics) + + +def test_optimization_performance_collector( + monkeypatch: pytest.MonkeyPatch, + dummy_context: Context, + test_keras_model: Path, + test_tflite_model: Path, +) -> None: + """Test optimization performance data collector.""" + device = EthosUConfiguration("ethos-u55-256") + + mock_performance_estimation(monkeypatch, device) + collector = EthosUOptimizationPerformance( + test_keras_model, + device, + [ + [ + {"optimization_type": "pruning", "optimization_target": 0.5}, + ] + ], + ) + collector.set_context(dummy_context) + result = collector.collect_data() + + assert isinstance(result, OptimizationPerformanceMetrics) + assert isinstance(result.original_perf_metrics, PerformanceMetrics) + assert isinstance(result.optimizations_perf_metrics, list) + assert len(result.optimizations_perf_metrics) == 1 + + opt, metrics = result.optimizations_perf_metrics[0] + assert opt == [OptimizationSettings("pruning", 0.5, None)] + assert isinstance(metrics, PerformanceMetrics) + + collector_no_optimizations = EthosUOptimizationPerformance( + test_keras_model, + device, + [], + ) + with pytest.raises(FunctionalityNotSupportedError): + collector_no_optimizations.collect_data() + + collector_tflite = EthosUOptimizationPerformance( + test_tflite_model, + device, + [ + [ + {"optimization_type": "pruning", "optimization_target": 0.5}, + ] + ], + ) + collector_tflite.set_context(dummy_context) + with pytest.raises(FunctionalityNotSupportedError): + collector_tflite.collect_data() + + with pytest.raises( + Exception, match="Optimization parameters expected to be a list" + ): + collector_bad_config = EthosUOptimizationPerformance( + test_keras_model, device, {"optimization_type": "pruning"} # type: ignore + ) + collector.set_context(dummy_context) + collector_bad_config.collect_data() + + +def mock_performance_estimation( + monkeypatch: pytest.MonkeyPatch, device: EthosUConfiguration +) -> None: + """Mock performance estimation.""" + metrics = PerformanceMetrics( + device, + NPUCycles(1, 2, 3, 4, 5, 6), + MemoryUsage(1, 2, 3, 4, 5), + ) + monkeypatch.setattr( + "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate", + MagicMock(return_value=metrics), + ) diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/mlia/test_devices_ethosu_performance.py new file mode 100644 index 0000000..e27efa0 --- /dev/null +++ b/tests/mlia/test_devices_ethosu_performance.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Performance estimation tests.""" +from unittest.mock import MagicMock + +import pytest + +from mlia.devices.ethosu.performance import MemorySizeType +from mlia.devices.ethosu.performance import MemoryUsage + + +def test_memory_usage_conversion() -> None: + """Test MemoryUsage objects conversion.""" + memory_usage_in_kb = MemoryUsage(1, 2, 3, 4, 5, MemorySizeType.KILOBYTES) + assert memory_usage_in_kb.in_kilobytes() == memory_usage_in_kb + + memory_usage_in_bytes = MemoryUsage( + 1 * 1024, 2 * 1024, 3 * 1024, 4 * 1024, 5 * 1024 + ) + assert memory_usage_in_bytes.in_kilobytes() == memory_usage_in_kb + + +def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None: + """Mock performance estimation.""" + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.estimate_performance", + MagicMock(return_value=MagicMock()), + ) diff --git a/tests/mlia/test_devices_ethosu_reporters.py b/tests/mlia/test_devices_ethosu_reporters.py new file mode 100644 index 0000000..2d5905c --- /dev/null +++ b/tests/mlia/test_devices_ethosu_reporters.py @@ -0,0 +1,434 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for reports module.""" +import json +import sys +from contextlib import ExitStack as doesnt_raise +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal + +import pytest + +from mlia.core.reporting import get_reporter +from mlia.core.reporting import produce_report +from mlia.core.reporting import Report +from mlia.core.reporting import Reporter +from mlia.core.reporting import Table +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.devices.ethosu.performance import MemoryUsage +from mlia.devices.ethosu.performance import NPUCycles +from mlia.devices.ethosu.performance import PerformanceMetrics +from mlia.devices.ethosu.reporters import find_appropriate_formatter +from mlia.devices.ethosu.reporters import report_device_details +from mlia.devices.ethosu.reporters import report_operators +from mlia.devices.ethosu.reporters import report_perf_metrics +from mlia.tools.vela_wrapper import NpuSupported +from mlia.tools.vela_wrapper import Operator +from mlia.tools.vela_wrapper import Operators +from mlia.utils.console import remove_ascii_codes + + +@pytest.mark.parametrize( + "data, formatters", + [ + ( + [Operator("test_operator", "test_type", NpuSupported(False, []))], + [report_operators], + ), + ( + PerformanceMetrics( + EthosUConfiguration("ethos-u55-256"), + NPUCycles(0, 0, 0, 0, 0, 0), + MemoryUsage(0, 0, 0, 0, 0), + ), + [report_perf_metrics], + ), + ], +) +@pytest.mark.parametrize( + "fmt, output, expected_error", + [ + [ + "unknown_format", + sys.stdout, + pytest.raises(Exception, match="Unknown format unknown_format"), + ], + [ + "plain_text", + sys.stdout, + doesnt_raise(), + ], + [ + "json", + sys.stdout, + doesnt_raise(), + ], + [ + "csv", + sys.stdout, + doesnt_raise(), + ], + [ + "plain_text", + "report.txt", + doesnt_raise(), + ], + [ + "json", + "report.json", + doesnt_raise(), + ], + [ + "csv", + "report.csv", + doesnt_raise(), + ], + ], +) +def test_report( + data: Any, + formatters: List[Callable], + fmt: Literal["plain_text", "json", "csv"], + output: Any, + expected_error: Any, + tmp_path: Path, +) -> None: + """Test report function.""" + if is_file := isinstance(output, str): + output = tmp_path / output + + for formatter in formatters: + with expected_error: + produce_report(data, formatter, fmt, output) + + if is_file: + assert output.is_file() + assert output.stat().st_size > 0 + + +@pytest.mark.parametrize( + "ops, expected_plain_text, expected_json_dict, expected_csv_list", + [ + ( + [ + Operator( + "npu_supported", + "test_type", + NpuSupported(True, []), + ), + Operator( + "cpu_only", + "test_type", + NpuSupported( + False, + [ + ( + "CPU only operator", + "", + ), + ], + ), + ), + Operator( + "npu_unsupported", + "test_type", + NpuSupported( + False, + [ + ( + "Not supported operator", + "Reason why operator is not supported", + ) + ], + ), + ), + ], + """ +Operators: +┌───┬─────────────────┬───────────────┬───────────┬───────────────────────────────┐ +│ # │ Operator name │ Operator type │ Placement │ Notes │ +╞═══╪═════════════════╪═══════════════╪═══════════╪═══════════════════════════════╡ +│ 1 │ npu_supported │ test_type │ NPU │ │ +├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤ +│ 2 │ cpu_only │ test_type │ CPU │ * CPU only operator │ +├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤ +│ 3 │ npu_unsupported │ test_type │ CPU │ * Not supported operator │ +│ │ │ │ │ │ +│ │ │ │ │ * Reason why operator is not │ +│ │ │ │ │ supported │ +└───┴─────────────────┴───────────────┴───────────┴───────────────────────────────┘ +""".strip(), + { + "operators": [ + { + "operator_name": "npu_supported", + "operator_type": "test_type", + "placement": "NPU", + "notes": [], + }, + { + "operator_name": "cpu_only", + "operator_type": "test_type", + "placement": "CPU", + "notes": [{"note": "CPU only operator"}], + }, + { + "operator_name": "npu_unsupported", + "operator_type": "test_type", + "placement": "CPU", + "notes": [ + {"note": "Not supported operator"}, + {"note": "Reason why operator is not supported"}, + ], + }, + ] + }, + [ + ["Operator name", "Operator type", "Placement", "Notes"], + ["npu_supported", "test_type", "NPU", ""], + ["cpu_only", "test_type", "CPU", "CPU only operator"], + [ + "npu_unsupported", + "test_type", + "CPU", + "Not supported operator;Reason why operator is not supported", + ], + ], + ), + ], +) +def test_report_operators( + ops: List[Operator], + expected_plain_text: str, + expected_json_dict: Dict, + expected_csv_list: List, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test report_operatos formatter.""" + # make terminal wide enough to print whole table + monkeypatch.setenv("COLUMNS", "100") + + report = report_operators(ops) + assert isinstance(report, Table) + + plain_text = remove_ascii_codes(report.to_plain_text()) + assert plain_text == expected_plain_text + + json_dict = report.to_json() + assert json_dict == expected_json_dict + + csv_list = report.to_csv() + assert csv_list == expected_csv_list + + +@pytest.mark.parametrize( + "device, expected_plain_text, expected_json_dict, expected_csv_list", + [ + [ + EthosUConfiguration("ethos-u55-256"), + """Device information: + Target ethos-u55 + MAC 256 + + Memory mode Shared_Sram + Const mem area Axi1 + Arena mem area Axi0 + Cache mem area Axi0 + Arena cache size 4,294,967,296 bytes + + System config Ethos_U55_High_End_Embedded + Accelerator clock 500,000,000 Hz + AXI0 port Sram + AXI1 port OffChipFlash + + Memory area settings: + Sram: + Clock scales 1.0 + Burst length 32 bytes + Read latency 32 cycles + Write latency 32 cycles + + Dram: + Clock scales 1.0 + Burst length 1 byte + Read latency 0 cycles + Write latency 0 cycles + + OnChipFlash: + Clock scales 1.0 + Burst length 1 byte + Read latency 0 cycles + Write latency 0 cycles + + OffChipFlash: + Clock scales 0.125 + Burst length 128 bytes + Read latency 64 cycles + Write latency 64 cycles + + Architecture settings: + Permanent storage mem area OffChipFlash + Feature map storage mem area Sram + Fast storage mem area Sram""", + { + "device": { + "target": "ethos-u55", + "mac": 256, + "memory_mode": { + "const_mem_area": "Axi1", + "arena_mem_area": "Axi0", + "cache_mem_area": "Axi0", + "arena_cache_size": {"value": 4294967296, "unit": "bytes"}, + }, + "system_config": { + "accelerator_clock": {"value": 500000000.0, "unit": "Hz"}, + "axi0_port": "Sram", + "axi1_port": "OffChipFlash", + "memory_area": { + "Sram": { + "clock_scales": 1.0, + "burst_length": {"value": 32, "unit": "bytes"}, + "read_latency": {"value": 32, "unit": "cycles"}, + "write_latency": {"value": 32, "unit": "cycles"}, + }, + "Dram": { + "clock_scales": 1.0, + "burst_length": {"value": 1, "unit": "byte"}, + "read_latency": {"value": 0, "unit": "cycles"}, + "write_latency": {"value": 0, "unit": "cycles"}, + }, + "OnChipFlash": { + "clock_scales": 1.0, + "burst_length": {"value": 1, "unit": "byte"}, + "read_latency": {"value": 0, "unit": "cycles"}, + "write_latency": {"value": 0, "unit": "cycles"}, + }, + "OffChipFlash": { + "clock_scales": 0.125, + "burst_length": {"value": 128, "unit": "bytes"}, + "read_latency": {"value": 64, "unit": "cycles"}, + "write_latency": {"value": 64, "unit": "cycles"}, + }, + }, + }, + "arch_settings": { + "permanent_storage_mem_area": "OffChipFlash", + "feature_map_storage_mem_area": "Sram", + "fast_storage_mem_area": "Sram", + }, + } + }, + [ + ( + "target", + "mac", + "memory_mode", + "const_mem_area", + "arena_mem_area", + "cache_mem_area", + "arena_cache_size_value", + "arena_cache_size_unit", + "system_config", + "accelerator_clock_value", + "accelerator_clock_unit", + "axi0_port", + "axi1_port", + "clock_scales", + "burst_length_value", + "burst_length_unit", + "read_latency_value", + "read_latency_unit", + "write_latency_value", + "write_latency_unit", + "permanent_storage_mem_area", + "feature_map_storage_mem_area", + "fast_storage_mem_area", + ), + ( + "ethos-u55", + 256, + "Shared_Sram", + "Axi1", + "Axi0", + "Axi0", + 4294967296, + "bytes", + "Ethos_U55_High_End_Embedded", + 500000000.0, + "Hz", + "Sram", + "OffChipFlash", + 0.125, + 128, + "bytes", + 64, + "cycles", + 64, + "cycles", + "OffChipFlash", + "Sram", + "Sram", + ), + ], + ], + ], +) +def test_report_device_details( + device: EthosUConfiguration, + expected_plain_text: str, + expected_json_dict: Dict, + expected_csv_list: List, +) -> None: + """Test report_operatos formatter.""" + report = report_device_details(device) + assert isinstance(report, Report) + + plain_text = report.to_plain_text() + assert plain_text == expected_plain_text + + json_dict = report.to_json() + assert json_dict == expected_json_dict + + csv_list = report.to_csv() + assert csv_list == expected_csv_list + + +def test_get_reporter(tmp_path: Path) -> None: + """Test reporter functionality.""" + ops = Operators( + [ + Operator( + "npu_supported", + "op_type", + NpuSupported(True, []), + ), + ] + ) + + output = tmp_path / "output.json" + with get_reporter("json", output, find_appropriate_formatter) as reporter: + assert isinstance(reporter, Reporter) + + with pytest.raises( + Exception, match="Unable to find appropriate formatter for some_data" + ): + reporter.submit("some_data") + + reporter.submit(ops) + + with open(output, encoding="utf-8") as file: + json_data = json.load(file) + + assert json_data == { + "operators_stats": [ + { + "npu_unsupported_ratio": 0.0, + "num_of_npu_supported_operators": 1, + "num_of_operators": 1, + } + ] + } diff --git a/tests/mlia/test_nn_tensorflow_config.py b/tests/mlia/test_nn_tensorflow_config.py new file mode 100644 index 0000000..1ac9f97 --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_config.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for config module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any + +import pytest + +from mlia.nn.tensorflow.config import get_model +from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.config import TFLiteModel +from mlia.nn.tensorflow.config import TfModel + + +def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None: + """Test Keras to TFLite conversion.""" + keras_model = KerasModel(test_keras_model) + + tflite_model_path = tmp_path / "test.tflite" + keras_model.convert_to_tflite(tflite_model_path) + + assert tflite_model_path.is_file() + assert tflite_model_path.stat().st_size > 0 + + +def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None: + """Test TensorFlow saved model to TFLite conversion.""" + tf_model = TfModel(test_tf_model) + + tflite_model_path = tmp_path / "test.tflite" + tf_model.convert_to_tflite(tflite_model_path) + + assert tflite_model_path.is_file() + assert tflite_model_path.stat().st_size > 0 + + +@pytest.mark.parametrize( + "model_path, expected_type, expected_error", + [ + ("test.tflite", TFLiteModel, does_not_raise()), + ("test.h5", KerasModel, does_not_raise()), + ("test.hdf5", KerasModel, does_not_raise()), + ( + "test.model", + None, + pytest.raises( + Exception, + match="The input model format is not supported" + r"\(supported formats: TFLite, Keras, TensorFlow saved model\)!", + ), + ), + ], +) +def test_get_model_file( + model_path: str, expected_type: type, expected_error: Any +) -> None: + """Test TFLite model type.""" + with expected_error: + model = get_model(model_path) + assert isinstance(model, expected_type) + + +@pytest.mark.parametrize( + "model_path, expected_type", [("tf_model_test_model", TfModel)] +) +def test_get_model_dir( + test_models_path: Path, model_path: str, expected_type: type +) -> None: + """Test TFLite model type.""" + model = get_model(str(test_models_path / model_path)) + assert isinstance(model, expected_type) diff --git a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py new file mode 100644 index 0000000..9bcf918 --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test for module optimizations/clustering.""" +from pathlib import Path +from typing import List +from typing import Optional + +import pytest +import tensorflow as tf + +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode +from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import save_tflite_model +from tests.mlia.utils.common import get_dataset +from tests.mlia.utils.common import train_model + + +def _prune_model( + model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]] +) -> tf.keras.Model: + x_train, y_train = get_dataset() + batch_size = 1 + num_epochs = 1 + + pruner = Pruner( + model, + PruningConfiguration( + target_sparsity, + layers_to_prune, + x_train, + y_train, + batch_size, + num_epochs, + ), + ) + pruner.apply_optimization() + pruned_model = pruner.get_model() + + return pruned_model + + +def _test_num_unique_weights( + metrics: TFLiteMetrics, + target_num_clusters: int, + layers_to_cluster: Optional[List[str]], +) -> None: + clustered_uniqueness_dict = metrics.num_unique_weights( + ReportClusterMode.NUM_CLUSTERS_PER_AXIS + ) + num_clustered_layers = 0 + num_optimizable_layers = len(clustered_uniqueness_dict) + if layers_to_cluster: + expected_num_clustered_layers = len(layers_to_cluster) + else: + expected_num_clustered_layers = num_optimizable_layers + for layer_name in clustered_uniqueness_dict: + # the +1 is there temporarily because of a bug that's been fixed + # but the fix hasn't been merged yet. + # Will need to be removed in the future. + if clustered_uniqueness_dict[layer_name][0] <= (target_num_clusters + 1): + num_clustered_layers = num_clustered_layers + 1 + # make sure we are having exactly as many clustered layers as we wanted + assert num_clustered_layers == expected_num_clustered_layers + + +def _test_sparsity( + metrics: TFLiteMetrics, + target_sparsity: float, + layers_to_cluster: Optional[List[str]], +) -> None: + pruned_sparsity_dict = metrics.sparsity_per_layer() + num_sparse_layers = 0 + num_optimizable_layers = len(pruned_sparsity_dict) + error_margin = 0.03 + if layers_to_cluster: + expected_num_sparse_layers = len(layers_to_cluster) + else: + expected_num_sparse_layers = num_optimizable_layers + for layer_name in pruned_sparsity_dict: + if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin: + num_sparse_layers = num_sparse_layers + 1 + # make sure we are having exactly as many sparse layers as we wanted + assert num_sparse_layers == expected_num_sparse_layers + + +@pytest.mark.skip(reason="Test fails randomly, further investigation is needed") +@pytest.mark.parametrize("target_num_clusters", (32, 4)) +@pytest.mark.parametrize("sparsity_aware", (False, True)) +@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None)) +def test_cluster_simple_model_fully( + target_num_clusters: int, + sparsity_aware: bool, + layers_to_cluster: Optional[List[str]], + tmp_path: Path, + test_keras_model: Path, +) -> None: + """Simple MNIST test to see if clustering works correctly.""" + target_sparsity = 0.5 + + base_model = tf.keras.models.load_model(str(test_keras_model)) + train_model(base_model) + + if sparsity_aware: + base_model = _prune_model(base_model, target_sparsity, layers_to_cluster) + + clusterer = Clusterer( + base_model, + ClusteringConfiguration( + target_num_clusters, + layers_to_cluster, + ), + ) + clusterer.apply_optimization() + clustered_model = clusterer.get_model() + + temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite" + tflite_clustered_model = convert_to_tflite(clustered_model) + save_tflite_model(tflite_clustered_model, temp_file) + clustered_tflite_metrics = TFLiteMetrics(str(temp_file)) + + _test_num_unique_weights( + clustered_tflite_metrics, target_num_clusters, layers_to_cluster + ) + + if sparsity_aware: + _test_sparsity(clustered_tflite_metrics, target_sparsity, layers_to_cluster) diff --git a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py new file mode 100644 index 0000000..64030a6 --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test for module optimizations/pruning.""" +from pathlib import Path +from typing import List +from typing import Optional + +import pytest +import tensorflow as tf +from numpy.core.numeric import isclose + +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import save_tflite_model +from tests.mlia.utils.common import get_dataset +from tests.mlia.utils.common import train_model + + +def _test_sparsity( + metrics: TFLiteMetrics, + target_sparsity: float, + layers_to_prune: Optional[List[str]], +) -> None: + pruned_sparsity_dict = metrics.sparsity_per_layer() + num_sparse_layers = 0 + num_optimizable_layers = len(pruned_sparsity_dict) + error_margin = 0.03 + if layers_to_prune: + expected_num_sparse_layers = len(layers_to_prune) + else: + expected_num_sparse_layers = num_optimizable_layers + for layer_name in pruned_sparsity_dict: + if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin: + num_sparse_layers = num_sparse_layers + 1 + # make sure we are having exactly as many sparse layers as we wanted + assert num_sparse_layers == expected_num_sparse_layers + + +def _test_check_sparsity(base_tflite_metrics: TFLiteMetrics) -> None: + """Assert the sparsity of a model is zero.""" + base_sparsity_dict = base_tflite_metrics.sparsity_per_layer() + for layer_name, sparsity in base_sparsity_dict.items(): + assert isclose( + sparsity, 0, atol=1e-2 + ), f"Sparsity for layer '{layer_name}' is {sparsity}, but should be zero." + + +def _get_tflite_metrics( + path: Path, tflite_fn: str, model: tf.keras.Model +) -> TFLiteMetrics: + """Save model as TFLiteModel and return metrics.""" + temp_file = path / tflite_fn + save_tflite_model(convert_to_tflite(model), temp_file) + return TFLiteMetrics(str(temp_file)) + + +@pytest.mark.parametrize("target_sparsity", (0.5, 0.9)) +@pytest.mark.parametrize("mock_data", (False, True)) +@pytest.mark.parametrize("layers_to_prune", (["conv1"], ["conv1", "conv2"], None)) +def test_prune_simple_model_fully( + target_sparsity: float, + mock_data: bool, + layers_to_prune: Optional[List[str]], + tmp_path: Path, + test_keras_model: Path, +) -> None: + """Simple MNIST test to see if pruning works correctly.""" + x_train, y_train = get_dataset() + batch_size = 1 + num_epochs = 1 + + base_model = tf.keras.models.load_model(str(test_keras_model)) + train_model(base_model) + + base_tflite_metrics = _get_tflite_metrics( + path=tmp_path, + tflite_fn="test_prune_simple_model_fully_before.tflite", + model=base_model, + ) + + # Make sure sparsity is zero before pruning + _test_check_sparsity(base_tflite_metrics) + + if mock_data: + pruner = Pruner( + base_model, + PruningConfiguration( + target_sparsity, + layers_to_prune, + ), + ) + + else: + pruner = Pruner( + base_model, + PruningConfiguration( + target_sparsity, + layers_to_prune, + x_train, + y_train, + batch_size, + num_epochs, + ), + ) + + pruner.apply_optimization() + pruned_model = pruner.get_model() + + pruned_tflite_metrics = _get_tflite_metrics( + path=tmp_path, + tflite_fn="test_prune_simple_model_fully_after.tflite", + model=pruned_model, + ) + + _test_sparsity(pruned_tflite_metrics, target_sparsity, layers_to_prune) diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/mlia/test_nn_tensorflow_optimizations_select.py new file mode 100644 index 0000000..5cac8ba --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_optimizations_select.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module select.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import List +from typing import Tuple + +import pytest +import tensorflow as tf + +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration +from mlia.nn.tensorflow.optimizations.select import get_optimizer +from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer +from mlia.nn.tensorflow.optimizations.select import OptimizationSettings + + +@pytest.mark.parametrize( + "config, expected_error, expected_type, expected_config", + [ + ( + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + does_not_raise(), + Pruner, + "pruning: 0.5", + ), + ( + PruningConfiguration(0.5), + does_not_raise(), + Pruner, + "pruning: 0.5", + ), + ( + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + does_not_raise(), + Clusterer, + "clustering: 32", + ), + ( + OptimizationSettings( + optimization_type="clustering", + optimization_target=0.5, + layers_to_optimize=None, + ), + pytest.raises( + Exception, + match="Optimization target should be a " + "positive integer. " + "Optimization target provided: 0.5", + ), + None, + None, + ), + ( + ClusteringConfiguration(32), + does_not_raise(), + Clusterer, + "clustering: 32", + ), + ( + OptimizationSettings( + optimization_type="superoptimization", + optimization_target="supertarget", # type: ignore + layers_to_optimize="all", # type: ignore + ), + pytest.raises( + Exception, + match="Unsupported optimization type: superoptimization", + ), + None, + None, + ), + ( + OptimizationSettings( + optimization_type="", + optimization_target=0.5, + layers_to_optimize=None, + ), + pytest.raises( + Exception, + match="Optimization type is not provided", + ), + None, + None, + ), + ( + "wrong_config", + pytest.raises( + Exception, + match="Unknown optimization configuration wrong_config", + ), + None, + None, + ), + ( + OptimizationSettings( + optimization_type="pruning", + optimization_target=None, # type: ignore + layers_to_optimize=None, + ), + pytest.raises( + Exception, + match="Optimization target is not provided", + ), + None, + None, + ), + ( + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + ], + does_not_raise(), + MultiStageOptimizer, + "pruning: 0.5 - clustering: 32", + ), + ], +) +def test_get_optimizer( + config: Any, + expected_error: Any, + expected_type: type, + expected_config: str, + test_keras_model: Path, +) -> None: + """Test function get_optimzer.""" + model = tf.keras.models.load_model(str(test_keras_model)) + + with expected_error: + optimizer = get_optimizer(model, config) + assert isinstance(optimizer, expected_type) + assert optimizer.optimization_config() == expected_config + + +@pytest.mark.parametrize( + "params, expected_result", + [ + ( + [], + [], + ), + ( + [("pruning", 0.5)], + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ) + ], + ), + ( + [("pruning", 0.5), ("clustering", 32)], + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + ], + ), + ], +) +def test_optimization_settings_create_from( + params: List[Tuple[str, float]], expected_result: List[OptimizationSettings] +) -> None: + """Test creating settings from parsed params.""" + assert OptimizationSettings.create_from(params) == expected_result + + +@pytest.mark.parametrize( + "settings, expected_next_target, expected_error", + [ + [ + OptimizationSettings("clustering", 32, None), + OptimizationSettings("clustering", 16, None), + does_not_raise(), + ], + [ + OptimizationSettings("clustering", 4, None), + OptimizationSettings("clustering", 4, None), + does_not_raise(), + ], + [ + OptimizationSettings("clustering", 10, None), + OptimizationSettings("clustering", 8, None), + does_not_raise(), + ], + [ + OptimizationSettings("pruning", 0.5, None), + OptimizationSettings("pruning", 0.6, None), + does_not_raise(), + ], + [ + OptimizationSettings("pruning", 0.9, None), + OptimizationSettings("pruning", 0.9, None), + does_not_raise(), + ], + [ + OptimizationSettings("super_optimization", 42, None), + None, + pytest.raises( + Exception, match="Unknown optimization type super_optimization" + ), + ], + ], +) +def test_optimization_settings_next_target( + settings: OptimizationSettings, + expected_next_target: OptimizationSettings, + expected_error: Any, +) -> None: + """Test getting next optimization target.""" + with expected_error: + assert settings.next_target() == expected_next_target diff --git a/tests/mlia/test_nn_tensorflow_tflite_metrics.py b/tests/mlia/test_nn_tensorflow_tflite_metrics.py new file mode 100644 index 0000000..805f7d1 --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_tflite_metrics.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test for module utils/tflite_metrics.""" +import os +import tempfile +from math import isclose +from pathlib import Path +from typing import Generator +from typing import List + +import numpy as np +import pytest +import tensorflow as tf + +from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode +from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics + + +def _dummy_keras_model() -> tf.keras.Model: + # Create a dummy model + keras_model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(8, 8, 3)), + tf.keras.layers.Conv2D(4, 3), + tf.keras.layers.DepthwiseConv2D(3), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(8), + ] + ) + return keras_model + + +def _sparse_binary_keras_model() -> tf.keras.Model: + def get_sparse_weights(shape: List[int]) -> np.array: + weights = np.zeros(shape) + with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator: + for idx, value in enumerate(weight_iterator): + if idx % 2 == 0: + value[...] = 1.0 + return weights + + keras_model = _dummy_keras_model() + # Assign weights to have 0.5 sparsity + for layer in keras_model.layers: + if not isinstance(layer, tf.keras.layers.Flatten): + weight = layer.weights[0] + weight.assign(get_sparse_weights(weight.shape)) + print(layer) + print(weight.numpy()) + return keras_model + + +@pytest.fixture(scope="class", name="tflite_file") +def fixture_tflite_file() -> Generator: + """Generate temporary TFLite file for tests.""" + converter = tf.lite.TFLiteConverter.from_keras_model(_sparse_binary_keras_model()) + tflite_model = converter.convert() + with tempfile.TemporaryDirectory() as tmp_dir: + file = os.path.join(tmp_dir, "test.tflite") + Path(file).write_bytes(tflite_model) + yield file + + +@pytest.fixture(scope="function", name="metrics") +def fixture_metrics(tflite_file: str) -> TFLiteMetrics: + """Generate metrics file for a given TFLite model.""" + return TFLiteMetrics(tflite_file) + + +class TestTFLiteMetrics: + """Tests for module TFLite_metrics.""" + + @staticmethod + def test_sparsity(metrics: TFLiteMetrics) -> None: + """Test sparsity.""" + # Create new instance with a dummy TFLite file + # Check sparsity calculation + sparsity_per_layer = metrics.sparsity_per_layer() + for name, sparsity in sparsity_per_layer.items(): + assert isclose(sparsity, 0.5), "Layer '{}' has incorrect sparsity.".format( + name + ) + assert isclose(metrics.sparsity_overall(), 0.5) + + @staticmethod + def test_clusters(metrics: TFLiteMetrics) -> None: + """Test clusters.""" + # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together + for mode in [ + ReportClusterMode.NUM_CLUSTERS_PER_AXIS, + ReportClusterMode.NUM_CLUSTERS_MIN_MAX, + ]: + num_unique_weights = metrics.num_unique_weights(mode) + for name, num_unique_per_axis in num_unique_weights.items(): + for num_unique in num_unique_per_axis: + assert ( + num_unique == 2 + ), "Layer '{}' has incorrect number of clusters.".format(name) + # NUM_CLUSTERS_HISTOGRAM + hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM) + assert hists + for name, hist in hists.items(): + assert hist + for idx, num_axes in enumerate(hist): + # The histogram starts with the bin for for num_clusters == 1 + num_clusters = idx + 1 + msg = ( + "Histogram of layer '{}': There are {} axes with {} " + "clusters".format(name, num_axes, num_clusters) + ) + if num_clusters == 2: + assert num_axes > 0, "{}, but there should be at least one.".format( + msg + ) + else: + assert num_axes == 0, "{}, but there should be none.".format(msg) + + @staticmethod + @pytest.mark.parametrize("report_sparsity", (False, True)) + @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode) + @pytest.mark.parametrize("max_num_clusters", (-1, 8)) + @pytest.mark.parametrize("verbose", (False, True)) + def test_summary( + tflite_file: str, + report_sparsity: bool, + report_cluster_mode: ReportClusterMode, + max_num_clusters: int, + verbose: bool, + ) -> None: + """Test the summary function.""" + for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]: + metrics.summary( + report_sparsity=report_sparsity, + report_cluster_mode=report_cluster_mode, + max_num_clusters=max_num_clusters, + verbose=verbose, + ) diff --git a/tests/mlia/test_nn_tensorflow_utils.py b/tests/mlia/test_nn_tensorflow_utils.py new file mode 100644 index 0000000..6d27299 --- /dev/null +++ b/tests/mlia/test_nn_tensorflow_utils.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test for module utils/test_utils.""" +from pathlib import Path + +import pytest +import tensorflow as tf + +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import get_tf_tensor_shape +from mlia.nn.tensorflow.utils import is_keras_model +from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.nn.tensorflow.utils import save_keras_model +from mlia.nn.tensorflow.utils import save_tflite_model + + +def test_convert_to_tflite(test_keras_model: Path) -> None: + """Test converting Keras model to TFLite.""" + keras_model = tf.keras.models.load_model(str(test_keras_model)) + tflite_model = convert_to_tflite(keras_model) + + assert tflite_model + + +def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None: + """Test saving Keras model.""" + keras_model = tf.keras.models.load_model(str(test_keras_model)) + + temp_file = tmp_path / "test_model_saving.h5" + save_keras_model(keras_model, temp_file) + loaded_model = tf.keras.models.load_model(temp_file) + + assert loaded_model.summary() == keras_model.summary() + + +def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None: + """Test saving TFLite model.""" + keras_model = tf.keras.models.load_model(str(test_keras_model)) + + tflite_model = convert_to_tflite(keras_model) + + temp_file = tmp_path / "test_model_saving.tflite" + save_tflite_model(tflite_model, temp_file) + + interpreter = tf.lite.Interpreter(model_path=str(temp_file)) + assert interpreter + + +@pytest.mark.parametrize( + "model_path, expected_result", + [ + [Path("sample_model.tflite"), True], + [Path("strange_model.tflite.tfl"), False], + [Path("sample_model.h5"), False], + [Path("sample_model"), False], + ], +) +def test_is_tflite_model(model_path: Path, expected_result: bool) -> None: + """Test function is_tflite_model.""" + result = is_tflite_model(model_path) + assert result == expected_result + + +@pytest.mark.parametrize( + "model_path, expected_result", + [ + [Path("sample_model.h5"), True], + [Path("strange_model.h5.keras"), False], + [Path("sample_model.tflite"), False], + [Path("sample_model"), False], + ], +) +def test_is_keras_model(model_path: Path, expected_result: bool) -> None: + """Test function is_keras_model.""" + result = is_keras_model(model_path) + assert result == expected_result + + +def test_get_tf_tensor_shape(test_tf_model: Path) -> None: + """Test get_tf_tensor_shape with test model.""" + assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1] diff --git a/tests/mlia/test_resources/vela/sample_vela.ini b/tests/mlia/test_resources/vela/sample_vela.ini new file mode 100644 index 0000000..c992458 --- /dev/null +++ b/tests/mlia/test_resources/vela/sample_vela.ini @@ -0,0 +1,47 @@ +; SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +; SPDX-License-Identifier: Apache-2.0 +; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s) +[System_Config.Ethos_U55_High_End_Embedded] +core_clock=500e6 +axi0_port=Sram +axi1_port=OffChipFlash +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +OffChipFlash_clock_scale=0.125 +OffChipFlash_burst_length=128 +OffChipFlash_read_latency=64 +OffChipFlash_write_latency=64 + +; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s) +[System_Config.Ethos_U65_High_End] +core_clock=1e9 +axi0_port=Sram +axi1_port=Dram +Sram_clock_scale=1.0 +Sram_burst_length=32 +Sram_read_latency=32 +Sram_write_latency=32 +Dram_clock_scale=0.234375 +Dram_burst_length=128 +Dram_read_latency=500 +Dram_write_latency=250 + +; ----------------------------------------------------------------------------- +; Memory Mode + +; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software +; The non-SRAM memory is assumed to be read-only +[Memory_Mode.Shared_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi0 +cache_mem_area=Axi0 + +; The SRAM (384KB) is only for use by the Ethos-U +; The non-SRAM memory is assumed to be read-writeable +[Memory_Mode.Dedicated_Sram] +const_mem_area=Axi1 +arena_mem_area=Axi1 +cache_mem_area=Axi0 +arena_cache_size=393216 diff --git a/tests/mlia/test_tools_aiet_wrapper.py b/tests/mlia/test_tools_aiet_wrapper.py new file mode 100644 index 0000000..ab55b71 --- /dev/null +++ b/tests/mlia/test_tools_aiet_wrapper.py @@ -0,0 +1,760 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module tools/aiet_wrapper.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from unittest.mock import MagicMock +from unittest.mock import PropertyMock + +import pytest + +from mlia.tools.aiet_wrapper import AIETRunner +from mlia.tools.aiet_wrapper import DeviceInfo +from mlia.tools.aiet_wrapper import estimate_performance +from mlia.tools.aiet_wrapper import ExecutionParams +from mlia.tools.aiet_wrapper import GenericInferenceOutputParser +from mlia.tools.aiet_wrapper import GenericInferenceRunnerEthosU +from mlia.tools.aiet_wrapper import get_aiet_runner +from mlia.tools.aiet_wrapper import get_generic_runner +from mlia.tools.aiet_wrapper import get_system_name +from mlia.tools.aiet_wrapper import is_supported +from mlia.tools.aiet_wrapper import ModelInfo +from mlia.tools.aiet_wrapper import PerformanceMetrics +from mlia.tools.aiet_wrapper import supported_backends +from mlia.utils.proc import RunningCommand + + +@pytest.mark.parametrize( + "data, is_ready, result, missed_keys", + [ + ( + [], + False, + {}, + [ + "npu_active_cycles", + "npu_axi0_rd_data_beat_received", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ), + ( + ["sample text"], + False, + {}, + [ + "npu_active_cycles", + "npu_axi0_rd_data_beat_received", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ), + ( + [ + ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"], + False, + {"npu_axi0_rd_data_beat_received": 123}, + [ + "npu_active_cycles", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ] + ), + ( + [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + "NPU TOTAL cycles: 6", + ], + True, + { + "npu_axi0_rd_data_beat_received": 1, + "npu_axi0_wr_data_beat_written": 2, + "npu_axi1_rd_data_beat_received": 3, + "npu_active_cycles": 4, + "npu_idle_cycles": 5, + "npu_total_cycles": 6, + }, + [], + ), + ], +) +def test_generic_inference_output_parser( + data: List[str], is_ready: bool, result: Dict, missed_keys: List[str] +) -> None: + """Test generic runner output parser.""" + parser = GenericInferenceOutputParser() + + for line in data: + parser.feed(line) + + assert parser.is_ready() == is_ready + assert parser.result == result + assert parser.missed_keys() == missed_keys + + +class TestAIETRunner: + """Tests for AIETRunner class.""" + + @staticmethod + def _setup_aiet( + monkeypatch: pytest.MonkeyPatch, + available_systems: Optional[List[str]] = None, + available_apps: Optional[List[str]] = None, + ) -> None: + """Set up AIET metadata.""" + + def mock_system(system: str) -> MagicMock: + """Mock the System instance.""" + mock = MagicMock() + type(mock).name = PropertyMock(return_value=system) + return mock + + def mock_app(app: str) -> MagicMock: + """Mock the Application instance.""" + mock = MagicMock() + type(mock).name = PropertyMock(return_value=app) + mock.can_run_on.return_value = True + return mock + + system_mocks = [mock_system(name) for name in (available_systems or [])] + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.get_available_systems", + MagicMock(return_value=system_mocks), + ) + + apps_mock = [mock_app(name) for name in (available_apps or [])] + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.get_available_applications", + MagicMock(return_value=apps_mock), + ) + + @pytest.mark.parametrize( + "available_systems, system, installed", + [ + ([], "system1", False), + (["system1", "system2"], "system1", True), + ], + ) + def test_is_system_installed( + self, + available_systems: List, + system: str, + installed: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method is_system_installed.""" + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + + self._setup_aiet(monkeypatch, available_systems) + + assert aiet_runner.is_system_installed(system) == installed + mock_executor.assert_not_called() + + @pytest.mark.parametrize( + "available_systems, systems", + [ + ([], []), + (["system1"], ["system1"]), + ], + ) + def test_installed_systems( + self, + available_systems: List[str], + systems: List[str], + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method installed_systems.""" + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + + self._setup_aiet(monkeypatch, available_systems) + assert aiet_runner.get_installed_systems() == systems + + mock_executor.assert_not_called() + + @staticmethod + def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None: + """Test system installation.""" + install_system_mock = MagicMock() + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.install_system", install_system_mock + ) + + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + aiet_runner.install_system(Path("test_system_path")) + + install_system_mock.assert_called_once_with(Path("test_system_path")) + mock_executor.assert_not_called() + + @pytest.mark.parametrize( + "available_systems, systems, expected_result", + [ + ([], [], False), + (["system1"], [], False), + (["system1"], ["system1"], True), + (["system1", "system2"], ["system1", "system3"], False), + (["system1", "system2"], ["system1", "system2"], True), + ], + ) + def test_systems_installed( + self, + available_systems: List[str], + systems: List[str], + expected_result: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method systems_installed.""" + self._setup_aiet(monkeypatch, available_systems) + + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + + assert aiet_runner.systems_installed(systems) is expected_result + + mock_executor.assert_not_called() + + @pytest.mark.parametrize( + "available_apps, applications, expected_result", + [ + ([], [], False), + (["app1"], [], False), + (["app1"], ["app1"], True), + (["app1", "app2"], ["app1", "app3"], False), + (["app1", "app2"], ["app1", "app2"], True), + ], + ) + def test_applications_installed( + self, + available_apps: List[str], + applications: List[str], + expected_result: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method applications_installed.""" + self._setup_aiet(monkeypatch, [], available_apps) + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + + assert aiet_runner.applications_installed(applications) is expected_result + mock_executor.assert_not_called() + + @pytest.mark.parametrize( + "available_apps, applications", + [ + ([], []), + ( + ["application1", "application2"], + ["application1", "application2"], + ), + ], + ) + def test_get_installed_applications( + self, + available_apps: List[str], + applications: List[str], + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method get_installed_applications.""" + mock_executor = MagicMock() + self._setup_aiet(monkeypatch, [], available_apps) + + aiet_runner = AIETRunner(mock_executor) + assert applications == aiet_runner.get_installed_applications() + + mock_executor.assert_not_called() + + @staticmethod + def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None: + """Test application installation.""" + mock_install_application = MagicMock() + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.install_application", mock_install_application + ) + + mock_executor = MagicMock() + + aiet_runner = AIETRunner(mock_executor) + aiet_runner.install_application(Path("test_application_path")) + mock_install_application.assert_called_once_with(Path("test_application_path")) + + mock_executor.assert_not_called() + + @pytest.mark.parametrize( + "available_apps, application, installed", + [ + ([], "system1", False), + ( + ["application1", "application2"], + "application1", + True, + ), + ( + [], + "application1", + False, + ), + ], + ) + def test_is_application_installed( + self, + available_apps: List[str], + application: str, + installed: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method is_application_installed.""" + self._setup_aiet(monkeypatch, [], available_apps) + + mock_executor = MagicMock() + aiet_runner = AIETRunner(mock_executor) + assert installed == aiet_runner.is_application_installed(application, "system1") + + mock_executor.assert_not_called() + + @staticmethod + @pytest.mark.parametrize( + "execution_params, expected_command", + [ + ( + ExecutionParams("application1", "system1", [], [], []), + ["aiet", "application", "run", "-n", "application1", "-s", "system1"], + ), + ( + ExecutionParams( + "application1", + "system1", + ["input_file=123.txt", "size=777"], + ["param1=456", "param2=789"], + ["source1.txt:dest1.txt", "source2.txt:dest2.txt"], + ), + [ + "aiet", + "application", + "run", + "-n", + "application1", + "-s", + "system1", + "-p", + "input_file=123.txt", + "-p", + "size=777", + "--system-param", + "param1=456", + "--system-param", + "param2=789", + "--deploy", + "source1.txt:dest1.txt", + "--deploy", + "source2.txt:dest2.txt", + ], + ), + ], + ) + def test_run_application( + execution_params: ExecutionParams, expected_command: List[str] + ) -> None: + """Test method run_application.""" + mock_executor = MagicMock() + mock_running_command = MagicMock() + mock_executor.submit.return_value = mock_running_command + + aiet_runner = AIETRunner(mock_executor) + aiet_runner.run_application(execution_params) + + mock_executor.submit.assert_called_once_with(expected_command) + + +@pytest.mark.parametrize( + "device, system, application, backend, expected_error", + [ + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", True), + "Corstone-300", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", False), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", True), + "Corstone-310", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", False), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-310", + pytest.raises( + Exception, + match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-310", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", True), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True), + "Corstone-300", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", False), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", True), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed", + ), + ), + ( + DeviceInfo( + device_type="unknown_device", # type: ignore + mac=None, # type: ignore + memory_mode="Shared_Sram", + ), + ("some_system", False), + ("some_application", False), + "some backend", + pytest.raises(Exception, match="Unsupported device unknown_device"), + ), + ], +) +def test_estimate_performance( + device: DeviceInfo, + system: Tuple[str, bool], + application: Tuple[str, bool], + backend: str, + expected_error: Any, + test_tflite_model: Path, + aiet_runner: MagicMock, +) -> None: + """Test getting performance estimations.""" + system_name, system_installed = system + application_name, application_installed = application + + aiet_runner.is_system_installed.return_value = system_installed + aiet_runner.is_application_installed.return_value = application_installed + + mock_process = create_mock_process( + [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + "NPU TOTAL cycles: 6", + ], + [], + ) + + mock_generic_inference_run = RunningCommand(mock_process) + aiet_runner.run_application.return_value = mock_generic_inference_run + + with expected_error: + perf_metrics = estimate_performance( + ModelInfo(test_tflite_model), device, backend + ) + + assert isinstance(perf_metrics, PerformanceMetrics) + assert perf_metrics == PerformanceMetrics( + npu_axi0_rd_data_beat_received=1, + npu_axi0_wr_data_beat_written=2, + npu_axi1_rd_data_beat_received=3, + npu_active_cycles=4, + npu_idle_cycles=5, + npu_total_cycles=6, + ) + + assert aiet_runner.is_system_installed.called_once_with(system_name) + assert aiet_runner.is_application_installed.called_once_with( + application_name, system_name + ) + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_estimate_performance_insufficient_data( + aiet_runner: MagicMock, test_tflite_model: Path, backend: str +) -> None: + """Test that performance could not be estimated when not all data presented.""" + aiet_runner.is_system_installed.return_value = True + aiet_runner.is_application_installed.return_value = True + + no_total_cycles_output = [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + ] + mock_process = create_mock_process( + no_total_cycles_output, + [], + ) + + mock_generic_inference_run = RunningCommand(mock_process) + aiet_runner.run_application.return_value = mock_generic_inference_run + + with pytest.raises( + Exception, match="Unable to get performance metrics, insufficient data" + ): + device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram") + estimate_performance(ModelInfo(test_tflite_model), device, backend) + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_estimate_performance_invalid_output( + test_tflite_model: Path, aiet_runner: MagicMock, backend: str +) -> None: + """Test estimation could not be done if inference produces unexpected output.""" + aiet_runner.is_system_installed.return_value = True + aiet_runner.is_application_installed.return_value = True + + mock_process = create_mock_process( + ["Something", "is", "wrong"], ["What a nice error!"] + ) + aiet_runner.run_application.return_value = RunningCommand(mock_process) + + with pytest.raises(Exception, match="Unable to get performance metrics"): + estimate_performance( + ModelInfo(test_tflite_model), + DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"), + backend=backend, + ) + + +def test_get_aiet_runner() -> None: + """Test getting aiet runner.""" + aiet_runner = get_aiet_runner() + assert isinstance(aiet_runner, AIETRunner) + + +def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock: + """Mock underlying process.""" + mock_process = MagicMock() + mock_process.poll.return_value = 0 + type(mock_process).stdout = PropertyMock(return_value=iter(stdout)) + type(mock_process).stderr = PropertyMock(return_value=iter(stderr)) + return mock_process + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_get_generic_runner(backend: str) -> None: + """Test function get_generic_runner().""" + device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram") + + runner = get_generic_runner(device_info=device_info, backend=backend) + assert isinstance(runner, GenericInferenceRunnerEthosU) + + with pytest.raises(RuntimeError): + get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND") + + +@pytest.mark.parametrize( + ("backend", "device_type"), + ( + ("Corstone-300", "ethos-u55"), + ("Corstone-300", "ethos-u65"), + ("Corstone-310", "ethos-u55"), + ), +) +def test_aiet_backend_support(backend: str, device_type: str) -> None: + """Test AIET backend & device support.""" + assert is_supported(backend) + assert is_supported(backend, device_type) + + assert get_system_name(backend, device_type) + + assert backend in supported_backends() + + +class TestGenericInferenceRunnerEthosU: + """Test for the class GenericInferenceRunnerEthosU.""" + + @staticmethod + @pytest.mark.parametrize( + "device, backend, expected_system, expected_app", + [ + [ + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U55", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), + "Corstone-310", + "Corstone-310: Cortex-M85+Ethos-U55", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Sram"), + "Corstone-310", + "Corstone-310: Cortex-M85+Ethos-U55", + "Generic Inference Runner: Ethos-U55 SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U55", + "Generic Inference Runner: Ethos-U55 SRAM", + ], + [ + DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U65", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U65", + "Generic Inference Runner: Ethos-U65 Dedicated SRAM", + ], + ], + ) + def test_artifact_resolver( + device: DeviceInfo, backend: str, expected_system: str, expected_app: str + ) -> None: + """Test artifact resolving based on the provided parameters.""" + generic_runner = get_generic_runner(device, backend) + assert isinstance(generic_runner, GenericInferenceRunnerEthosU) + + assert generic_runner.system_name == expected_system + assert generic_runner.app_name == expected_app + + @staticmethod + def test_artifact_resolver_unsupported_backend() -> None: + """Test that it should be not possible to use unsupported backends.""" + with pytest.raises( + RuntimeError, match="Unsupported device ethos-u65 for backend test_backend" + ): + get_generic_runner( + DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend" + ) + + @staticmethod + def test_artifact_resolver_unsupported_memory_mode() -> None: + """Test that it should be not possible to use unsupported memory modes.""" + with pytest.raises( + RuntimeError, match="Unsupported memory mode test_memory_mode" + ): + get_generic_runner( + DeviceInfo( + "ethos-u65", + 256, + memory_mode="test_memory_mode", # type: ignore + ), + "Corstone-300", + ) + + @staticmethod + @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) + def test_inference_should_fail_if_system_not_installed( + aiet_runner: MagicMock, test_tflite_model: Path, backend: str + ) -> None: + """Test that inference should fail if system is not installed.""" + aiet_runner.is_system_installed.return_value = False + + generic_runner = get_generic_runner( + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend + ) + with pytest.raises( + Exception, + match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed", + ): + generic_runner.run(ModelInfo(test_tflite_model), []) + + @staticmethod + @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) + def test_inference_should_fail_is_apps_not_installed( + aiet_runner: MagicMock, test_tflite_model: Path, backend: str + ) -> None: + """Test that inference should fail if apps are not installed.""" + aiet_runner.is_system_installed.return_value = True + aiet_runner.is_application_installed.return_value = False + + generic_runner = get_generic_runner( + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend + ) + with pytest.raises( + Exception, + match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM" + r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not " + r"installed", + ): + generic_runner.run(ModelInfo(test_tflite_model), []) + + +@pytest.fixture(name="aiet_runner") +def fixture_aiet_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Mock AIET runner.""" + aiet_runner_mock = MagicMock(spec=AIETRunner) + monkeypatch.setattr( + "mlia.tools.aiet_wrapper.get_aiet_runner", + MagicMock(return_value=aiet_runner_mock), + ) + return aiet_runner_mock diff --git a/tests/mlia/test_tools_metadata_common.py b/tests/mlia/test_tools_metadata_common.py new file mode 100644 index 0000000..7663b83 --- /dev/null +++ b/tests/mlia/test_tools_metadata_common.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for commmon installation related functions.""" +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from unittest.mock import call +from unittest.mock import MagicMock +from unittest.mock import PropertyMock + +import pytest + +from mlia.tools.metadata.common import DefaultInstallationManager +from mlia.tools.metadata.common import DownloadAndInstall +from mlia.tools.metadata.common import Installation +from mlia.tools.metadata.common import InstallationType +from mlia.tools.metadata.common import InstallFromPath + + +def get_installation_mock( + name: str, + already_installed: bool = False, + could_be_installed: bool = False, + supported_install_type: Optional[type] = None, +) -> MagicMock: + """Get mock instance for the installation.""" + mock = MagicMock(spec=Installation) + + def supports(install_type: InstallationType) -> bool: + if supported_install_type is None: + return False + + return isinstance(install_type, supported_install_type) + + mock.supports.side_effect = supports + + props = { + "name": name, + "already_installed": already_installed, + "could_be_installed": could_be_installed, + } + for prop, value in props.items(): + setattr(type(mock), prop, PropertyMock(return_value=value)) + + return mock + + +def _already_installed_mock() -> MagicMock: + return get_installation_mock( + name="already_installed", + already_installed=True, + ) + + +def _ready_for_installation_mock() -> MagicMock: + return get_installation_mock( + name="ready_for_installation", + already_installed=False, + could_be_installed=True, + ) + + +def _could_be_downloaded_and_installed_mock() -> MagicMock: + return get_installation_mock( + name="could_be_downloaded_and_installed", + already_installed=False, + could_be_installed=True, + supported_install_type=DownloadAndInstall, + ) + + +def _could_be_installed_from_mock() -> MagicMock: + return get_installation_mock( + name="could_be_installed_from", + already_installed=False, + could_be_installed=True, + supported_install_type=InstallFromPath, + ) + + +def get_installation_manager( + noninteractive: bool, + installations: List[Any], + monkeypatch: pytest.MonkeyPatch, + yes_response: bool = True, +) -> DefaultInstallationManager: + """Get installation manager instance.""" + if not noninteractive: + monkeypatch.setattr( + "mlia.tools.metadata.common.yes", MagicMock(return_value=yes_response) + ) + + return DefaultInstallationManager(installations, noninteractive=noninteractive) + + +def test_installation_manager_filtering() -> None: + """Test default installation manager.""" + already_installed = _already_installed_mock() + ready_for_installation = _ready_for_installation_mock() + could_be_downloaded_and_installed = _could_be_downloaded_and_installed_mock() + + manager = DefaultInstallationManager( + [ + already_installed, + ready_for_installation, + could_be_downloaded_and_installed, + ] + ) + assert manager.already_installed() == [already_installed] + assert manager.ready_for_installation() == [ + ready_for_installation, + could_be_downloaded_and_installed, + ] + assert manager.could_be_downloaded_and_installed() == [ + could_be_downloaded_and_installed + ] + assert manager.could_be_downloaded_and_installed("some_installation") == [] + + +@pytest.mark.parametrize("noninteractive", [True, False]) +@pytest.mark.parametrize( + "install_mock, eula_agreement, backend_name, expected_call", + [ + [ + _could_be_downloaded_and_installed_mock(), + True, + None, + [call(DownloadAndInstall(eula_agreement=True))], + ], + [ + _could_be_downloaded_and_installed_mock(), + False, + None, + [call(DownloadAndInstall(eula_agreement=False))], + ], + [ + _could_be_downloaded_and_installed_mock(), + False, + "unknown", + [], + ], + ], +) +def test_installation_manager_download_and_install( + install_mock: MagicMock, + noninteractive: bool, + eula_agreement: bool, + backend_name: Optional[str], + expected_call: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test installation process.""" + install_mock.reset_mock() + + manager = get_installation_manager(noninteractive, [install_mock], monkeypatch) + + manager.download_and_install(backend_name, eula_agreement=eula_agreement) + assert install_mock.install.mock_calls == expected_call + + +@pytest.mark.parametrize("noninteractive", [True, False]) +@pytest.mark.parametrize( + "install_mock, backend_name, expected_call", + [ + [ + _could_be_installed_from_mock(), + None, + [call(InstallFromPath(Path("some_path")))], + ], + [ + _could_be_installed_from_mock(), + "unknown", + [], + ], + [ + _already_installed_mock(), + "already_installed", + [], + ], + ], +) +def test_installation_manager_install_from( + install_mock: MagicMock, + noninteractive: bool, + backend_name: Optional[str], + expected_call: Any, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test installation process.""" + install_mock.reset_mock() + + manager = get_installation_manager(noninteractive, [install_mock], monkeypatch) + manager.install_from(Path("some_path"), backend_name) + + assert install_mock.install.mock_calls == expected_call diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/mlia/test_tools_metadata_corstone.py new file mode 100644 index 0000000..2ce3610 --- /dev/null +++ b/tests/mlia/test_tools_metadata_corstone.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Corstone related installation functions..""" +import tarfile +from pathlib import Path +from typing import List +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from mlia.tools.aiet_wrapper import AIETRunner +from mlia.tools.metadata.common import DownloadAndInstall +from mlia.tools.metadata.common import InstallFromPath +from mlia.tools.metadata.corstone import AIETBasedInstallation +from mlia.tools.metadata.corstone import AIETMetadata +from mlia.tools.metadata.corstone import BackendInfo +from mlia.tools.metadata.corstone import BackendInstaller +from mlia.tools.metadata.corstone import CompoundPathChecker +from mlia.tools.metadata.corstone import Corstone300Installer +from mlia.tools.metadata.corstone import get_corstone_installations +from mlia.tools.metadata.corstone import PackagePathChecker +from mlia.tools.metadata.corstone import PathChecker +from mlia.tools.metadata.corstone import StaticPathChecker + + +@pytest.fixture(name="test_mlia_resources") +def fixture_test_mlia_resources( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> Path: + """Redirect MLIA resources resolution to the temp directory.""" + mlia_resources = tmp_path / "resources" + mlia_resources.mkdir() + + monkeypatch.setattr( + "mlia.tools.metadata.corstone.get_mlia_resources", + MagicMock(return_value=mlia_resources), + ) + + return mlia_resources + + +def get_aiet_based_installation( # pylint: disable=too-many-arguments + aiet_runner_mock: MagicMock = MagicMock(), + name: str = "test_name", + description: str = "test_description", + download_artifact: Optional[MagicMock] = None, + path_checker: PathChecker = MagicMock(), + apps_resources: Optional[List[str]] = None, + system_config: Optional[str] = None, + backend_installer: BackendInstaller = MagicMock(), + supported_platforms: Optional[List[str]] = None, +) -> AIETBasedInstallation: + """Get AIET based installation.""" + return AIETBasedInstallation( + aiet_runner=aiet_runner_mock, + metadata=AIETMetadata( + name=name, + description=description, + system_config=system_config or "", + apps_resources=apps_resources or [], + fvp_dir_name="sample_dir", + download_artifact=download_artifact, + supported_platforms=supported_platforms, + ), + path_checker=path_checker, + backend_installer=backend_installer, + ) + + +@pytest.mark.parametrize( + "platform, supported_platforms, expected_result", + [ + ["Linux", ["Linux"], True], + ["Linux", [], True], + ["Linux", None, True], + ["Windows", ["Linux"], False], + ], +) +def test_could_be_installed_depends_on_platform( + platform: str, + supported_platforms: Optional[List[str]], + expected_result: bool, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that installation could not be installed on unsupported platform.""" + monkeypatch.setattr( + "mlia.tools.metadata.corstone.platform.system", MagicMock(return_value=platform) + ) + monkeypatch.setattr( + "mlia.tools.metadata.corstone.all_paths_valid", MagicMock(return_value=True) + ) + aiet_runner_mock = MagicMock(spec=AIETRunner) + + installation = get_aiet_based_installation( + aiet_runner_mock, + supported_platforms=supported_platforms, + ) + assert installation.could_be_installed == expected_result + + +def test_get_corstone_installations() -> None: + """Test function get_corstone_installation.""" + installs = get_corstone_installations() + assert len(installs) == 2 + assert all(isinstance(install, AIETBasedInstallation) for install in installs) + + +def test_aiet_based_installation_metadata_resolving() -> None: + """Test AIET based installation metadata resolving.""" + aiet_runner_mock = MagicMock(spec=AIETRunner) + installation = get_aiet_based_installation(aiet_runner_mock) + + assert installation.name == "test_name" + assert installation.description == "test_description" + + aiet_runner_mock.all_installed.return_value = False + assert installation.already_installed is False + + assert installation.could_be_installed is True + + +def test_aiet_based_installation_supported_install_types(tmp_path: Path) -> None: + """Test supported installation types.""" + installation_no_download_artifact = get_aiet_based_installation() + assert installation_no_download_artifact.supports(DownloadAndInstall()) is False + + installation_with_download_artifact = get_aiet_based_installation( + download_artifact=MagicMock() + ) + assert installation_with_download_artifact.supports(DownloadAndInstall()) is True + + path_checker_mock = MagicMock(return_value=BackendInfo(tmp_path)) + installation_can_install_from_dir = get_aiet_based_installation( + path_checker=path_checker_mock + ) + assert installation_can_install_from_dir.supports(InstallFromPath(tmp_path)) is True + + any_installation = get_aiet_based_installation() + assert any_installation.supports("unknown_install_type") is False # type: ignore + + +def test_aiet_based_installation_install_wrong_type() -> None: + """Test that operation should fail if wrong install type provided.""" + with pytest.raises(Exception, match="Unable to install wrong_install_type"): + aiet_runner_mock = MagicMock(spec=AIETRunner) + installation = get_aiet_based_installation(aiet_runner_mock) + + installation.install("wrong_install_type") # type: ignore + + +def test_aiet_based_installation_install_from_path( + tmp_path: Path, test_mlia_resources: Path +) -> None: + """Test installation from the path.""" + system_config = test_mlia_resources / "example_config.json" + system_config.touch() + + sample_app = test_mlia_resources / "sample_app" + sample_app.mkdir() + + dist_dir = tmp_path / "dist" + dist_dir.mkdir() + + path_checker_mock = MagicMock(return_value=BackendInfo(dist_dir)) + + aiet_runner_mock = MagicMock(spec=AIETRunner) + installation = get_aiet_based_installation( + aiet_runner_mock=aiet_runner_mock, + path_checker=path_checker_mock, + apps_resources=[sample_app.name], + system_config="example_config.json", + ) + + assert installation.supports(InstallFromPath(dist_dir)) is True + installation.install(InstallFromPath(dist_dir)) + + aiet_runner_mock.install_system.assert_called_once() + aiet_runner_mock.install_application.assert_called_once_with(sample_app) + + +@pytest.mark.parametrize("copy_source", [True, False]) +def test_aiet_based_installation_install_from_static_path( + tmp_path: Path, test_mlia_resources: Path, copy_source: bool +) -> None: + """Test installation from the predefined path.""" + system_config = test_mlia_resources / "example_config.json" + system_config.touch() + + custom_system_config = test_mlia_resources / "custom_config.json" + custom_system_config.touch() + + sample_app = test_mlia_resources / "sample_app" + sample_app.mkdir() + + predefined_location = tmp_path / "backend" + predefined_location.mkdir() + + predefined_location_file = predefined_location / "file.txt" + predefined_location_file.touch() + + predefined_location_dir = predefined_location / "folder" + predefined_location_dir.mkdir() + nested_file = predefined_location_dir / "nested_file.txt" + nested_file.touch() + + aiet_runner_mock = MagicMock(spec=AIETRunner) + + def check_install_dir(install_dir: Path) -> None: + """Check content of the install dir.""" + assert install_dir.is_dir() + files = list(install_dir.iterdir()) + + if copy_source: + assert len(files) == 3 + assert all(install_dir / item in files for item in ["file.txt", "folder"]) + assert (install_dir / "folder/nested_file.txt").is_file() + else: + assert len(files) == 1 + + assert install_dir / "custom_config.json" in files + + aiet_runner_mock.install_system.side_effect = check_install_dir + + installation = get_aiet_based_installation( + aiet_runner_mock=aiet_runner_mock, + path_checker=StaticPathChecker( + predefined_location, + ["file.txt"], + copy_source=copy_source, + system_config=str(custom_system_config), + ), + apps_resources=[sample_app.name], + system_config="example_config.json", + ) + + assert installation.supports(InstallFromPath(predefined_location)) is True + installation.install(InstallFromPath(predefined_location)) + + aiet_runner_mock.install_system.assert_called_once() + aiet_runner_mock.install_application.assert_called_once_with(sample_app) + + +def create_sample_fvp_archive(tmp_path: Path) -> Path: + """Create sample FVP tar archive.""" + fvp_archive_dir = tmp_path / "archive" + fvp_archive_dir.mkdir() + + sample_file = fvp_archive_dir / "sample.txt" + sample_file.write_text("Sample file") + + sample_dir = fvp_archive_dir / "sample_dir" + sample_dir.mkdir() + + fvp_archive = tmp_path / "archive.tgz" + with tarfile.open(fvp_archive, "w:gz") as fvp_archive_tar: + fvp_archive_tar.add(fvp_archive_dir, arcname=fvp_archive_dir.name) + + return fvp_archive + + +def test_aiet_based_installation_download_and_install( + test_mlia_resources: Path, tmp_path: Path +) -> None: + """Test downloading and installation process.""" + fvp_archive = create_sample_fvp_archive(tmp_path) + + system_config = test_mlia_resources / "example_config.json" + system_config.touch() + + download_artifact_mock = MagicMock() + download_artifact_mock.download_to.return_value = fvp_archive + + path_checker = PackagePathChecker(["archive/sample.txt"], "archive/sample_dir") + + def installer(_eula_agreement: bool, dist_dir: Path) -> Path: + """Sample installer.""" + return dist_dir + + aiet_runner_mock = MagicMock(spec=AIETRunner) + installation = get_aiet_based_installation( + aiet_runner_mock, + download_artifact=download_artifact_mock, + backend_installer=installer, + path_checker=path_checker, + system_config="example_config.json", + ) + + installation.install(DownloadAndInstall()) + + aiet_runner_mock.install_system.assert_called_once() + + +@pytest.mark.parametrize( + "dir_content, expected_result", + [ + [ + ["models/", "file1.txt", "file2.txt"], + "models", + ], + [ + ["file1.txt", "file2.txt"], + None, + ], + [ + ["models/", "file2.txt"], + None, + ], + ], +) +def test_corstone_path_checker_valid_path( + tmp_path: Path, dir_content: List[str], expected_result: Optional[str] +) -> None: + """Test Corstone path checker valid scenario.""" + path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models") + + for item in dir_content: + if item.endswith("/"): + item_dir = tmp_path / item + item_dir.mkdir() + else: + item_file = tmp_path / item + item_file.touch() + + result = path_checker(tmp_path) + expected = ( + None if expected_result is None else BackendInfo(tmp_path / expected_result) + ) + + assert result == expected + + +@pytest.mark.parametrize("system_config", [None, "system_config"]) +@pytest.mark.parametrize("copy_source", [True, False]) +def test_static_path_checker( + tmp_path: Path, copy_source: bool, system_config: Optional[str] +) -> None: + """Test static path checker.""" + static_checker = StaticPathChecker( + tmp_path, [], copy_source=copy_source, system_config=system_config + ) + assert static_checker(tmp_path) == BackendInfo( + tmp_path, copy_source=copy_source, system_config=system_config + ) + + +def test_static_path_checker_not_valid_path(tmp_path: Path) -> None: + """Test static path checker should return None if path is not valid.""" + static_checker = StaticPathChecker(tmp_path, ["file.txt"]) + assert static_checker(tmp_path / "backend") is None + + +def test_static_path_checker_not_valid_structure(tmp_path: Path) -> None: + """Test static path checker should return None if files are missing.""" + static_checker = StaticPathChecker(tmp_path, ["file.txt"]) + assert static_checker(tmp_path) is None + + missing_file = tmp_path / "file.txt" + missing_file.touch() + + assert static_checker(tmp_path) == BackendInfo(tmp_path, copy_source=False) + + +def test_compound_path_checker(tmp_path: Path) -> None: + """Test compound path checker.""" + path_checker_path_valid_path = MagicMock(return_value=BackendInfo(tmp_path)) + path_checker_path_not_valid_path = MagicMock(return_value=None) + + checker = CompoundPathChecker( + path_checker_path_valid_path, path_checker_path_not_valid_path + ) + assert checker(tmp_path) == BackendInfo(tmp_path) + + checker = CompoundPathChecker(path_checker_path_not_valid_path) + assert checker(tmp_path) is None + + +@pytest.mark.parametrize( + "eula_agreement, expected_command", + [ + [ + True, + [ + "./FVP_Corstone_SSE-300.sh", + "-q", + "-d", + "corstone-300", + ], + ], + [ + False, + [ + "./FVP_Corstone_SSE-300.sh", + "-q", + "-d", + "corstone-300", + "--nointeractive", + "--i-agree-to-the-contained-eula", + ], + ], + ], +) +def test_corstone_300_installer( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + eula_agreement: bool, + expected_command: List[str], +) -> None: + """Test Corstone-300 installer.""" + command_mock = MagicMock() + + monkeypatch.setattr( + "mlia.tools.metadata.corstone.subprocess.check_call", command_mock + ) + installer = Corstone300Installer() + result = installer(eula_agreement, tmp_path) + + command_mock.assert_called_once_with(expected_command) + assert result == tmp_path / "corstone-300" diff --git a/tests/mlia/test_tools_vela_wrapper.py b/tests/mlia/test_tools_vela_wrapper.py new file mode 100644 index 0000000..875d2ff --- /dev/null +++ b/tests/mlia/test_tools_vela_wrapper.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module tools/vela_wrapper.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from ethosu.vela.compiler_driver import TensorAllocator +from ethosu.vela.scheduler import OptimizationStrategy + +from mlia.devices.ethosu.config import EthosUConfiguration +from mlia.tools.vela_wrapper import estimate_performance +from mlia.tools.vela_wrapper import generate_supported_operators_report +from mlia.tools.vela_wrapper import NpuSupported +from mlia.tools.vela_wrapper import Operator +from mlia.tools.vela_wrapper import Operators +from mlia.tools.vela_wrapper import optimize_model +from mlia.tools.vela_wrapper import OptimizedModel +from mlia.tools.vela_wrapper import PerformanceMetrics +from mlia.tools.vela_wrapper import supported_operators +from mlia.tools.vela_wrapper import VelaCompiler +from mlia.tools.vela_wrapper import VelaCompilerOptions +from mlia.utils.proc import working_directory + + +def test_default_vela_compiler() -> None: + """Test default Vela compiler instance.""" + default_compiler_options = VelaCompilerOptions(accelerator_config="ethos-u55-256") + default_compiler = VelaCompiler(default_compiler_options) + + assert default_compiler.config_files is None + assert default_compiler.system_config == "internal-default" + assert default_compiler.memory_mode == "internal-default" + assert default_compiler.accelerator_config == "ethos-u55-256" + assert default_compiler.max_block_dependency == 3 + assert default_compiler.arena_cache_size is None + assert default_compiler.tensor_allocator == TensorAllocator.HillClimb + assert default_compiler.cpu_tensor_alignment == 16 + assert default_compiler.optimization_strategy == OptimizationStrategy.Performance + assert default_compiler.output_dir is None + + assert default_compiler.get_config() == { + "accelerator_config": "ethos-u55-256", + "system_config": "internal-default", + "core_clock": 500000000.0, + "axi0_port": "Sram", + "axi1_port": "OffChipFlash", + "memory_mode": "internal-default", + "const_mem_area": "Axi1", + "arena_mem_area": "Axi0", + "cache_mem_area": "Axi0", + "arena_cache_size": 4294967296, + "permanent_storage_mem_area": "OffChipFlash", + "feature_map_storage_mem_area": "Sram", + "fast_storage_mem_area": "Sram", + "memory_area": { + "Sram": { + "clock_scales": 1.0, + "burst_length": 32, + "read_latency": 32, + "write_latency": 32, + }, + "Dram": { + "clock_scales": 1.0, + "burst_length": 1, + "read_latency": 0, + "write_latency": 0, + }, + "OnChipFlash": { + "clock_scales": 1.0, + "burst_length": 1, + "read_latency": 0, + "write_latency": 0, + }, + "OffChipFlash": { + "clock_scales": 0.125, + "burst_length": 128, + "read_latency": 64, + "write_latency": 64, + }, + }, + } + + +def test_vela_compiler_with_parameters(test_resources_path: Path) -> None: + """Test creation of Vela compiler instance with non-default params.""" + vela_ini_path = str(test_resources_path / "vela/sample_vela.ini") + + compiler_options = VelaCompilerOptions( + config_files=vela_ini_path, + system_config="Ethos_U65_High_End", + memory_mode="Shared_Sram", + accelerator_config="ethos-u65-256", + max_block_dependency=1, + arena_cache_size=10, + tensor_allocator="Greedy", + cpu_tensor_alignment=4, + optimization_strategy="Size", + output_dir="output", + ) + compiler = VelaCompiler(compiler_options) + + assert compiler.config_files == vela_ini_path + assert compiler.system_config == "Ethos_U65_High_End" + assert compiler.memory_mode == "Shared_Sram" + assert compiler.accelerator_config == "ethos-u65-256" + assert compiler.max_block_dependency == 1 + assert compiler.arena_cache_size == 10 + assert compiler.tensor_allocator == TensorAllocator.Greedy + assert compiler.cpu_tensor_alignment == 4 + assert compiler.optimization_strategy == OptimizationStrategy.Size + assert compiler.output_dir == "output" + + assert compiler.get_config() == { + "accelerator_config": "ethos-u65-256", + "system_config": "Ethos_U65_High_End", + "core_clock": 1000000000.0, + "axi0_port": "Sram", + "axi1_port": "Dram", + "memory_mode": "Shared_Sram", + "const_mem_area": "Axi1", + "arena_mem_area": "Axi0", + "cache_mem_area": "Axi0", + "arena_cache_size": 10, + "permanent_storage_mem_area": "Dram", + "feature_map_storage_mem_area": "Sram", + "fast_storage_mem_area": "Sram", + "memory_area": { + "Sram": { + "clock_scales": 1.0, + "burst_length": 32, + "read_latency": 32, + "write_latency": 32, + }, + "Dram": { + "clock_scales": 0.234375, + "burst_length": 128, + "read_latency": 500, + "write_latency": 250, + }, + "OnChipFlash": { + "clock_scales": 1.0, + "burst_length": 1, + "read_latency": 0, + "write_latency": 0, + }, + "OffChipFlash": { + "clock_scales": 1.0, + "burst_length": 1, + "read_latency": 0, + "write_latency": 0, + }, + }, + } + + +def test_compile_model(test_tflite_model: Path) -> None: + """Test model optimization.""" + compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options) + + optimized_model = compiler.compile_model(test_tflite_model) + assert isinstance(optimized_model, OptimizedModel) + + +def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None: + """Test model optimization and saving into file.""" + tmp_file = tmp_path / "temp.tflite" + + device = EthosUConfiguration("ethos-u55-256") + optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute()) + + assert tmp_file.is_file() + assert tmp_file.stat().st_size > 0 + + +@pytest.mark.parametrize( + "model, expected_ops", + [ + ( + "test_model.tflite", + Operators( + ops=[ + Operator( + name="sequential/conv1/Relu;sequential/conv1/BiasAdd;" + "sequential/conv2/Conv2D;sequential/conv1/Conv2D", + op_type="CONV_2D", + run_on_npu=NpuSupported(supported=True, reasons=[]), + ), + Operator( + name="sequential/conv2/Relu;sequential/conv2/BiasAdd;" + "sequential/conv2/Conv2D", + op_type="CONV_2D", + run_on_npu=NpuSupported(supported=True, reasons=[]), + ), + Operator( + name="sequential/max_pooling2d/MaxPool", + op_type="MAX_POOL_2D", + run_on_npu=NpuSupported(supported=True, reasons=[]), + ), + Operator( + name="sequential/flatten/Reshape", + op_type="RESHAPE", + run_on_npu=NpuSupported(supported=True, reasons=[]), + ), + Operator( + name="Identity", + op_type="FULLY_CONNECTED", + run_on_npu=NpuSupported(supported=True, reasons=[]), + ), + ] + ), + ) + ], +) +def test_operators(test_models_path: Path, model: str, expected_ops: Operators) -> None: + """Test operators function.""" + device = EthosUConfiguration("ethos-u55-256") + + operators = supported_operators(test_models_path / model, device.compiler_options) + for expected, actual in zip(expected_ops.ops, operators.ops): + # do not compare names as they could be different on each model generation + assert expected.op_type == actual.op_type + assert expected.run_on_npu == actual.run_on_npu + + +def test_estimate_performance(test_tflite_model: Path) -> None: + """Test getting performance estimations.""" + device = EthosUConfiguration("ethos-u55-256") + perf_metrics = estimate_performance(test_tflite_model, device.compiler_options) + + assert isinstance(perf_metrics, PerformanceMetrics) + + +def test_estimate_performance_already_optimized( + tmp_path: Path, test_tflite_model: Path +) -> None: + """Test that performance estimation should fail for already optimized model.""" + device = EthosUConfiguration("ethos-u55-256") + + optimized_model_path = tmp_path / "optimized_model.tflite" + + optimize_model(test_tflite_model, device.compiler_options, optimized_model_path) + + with pytest.raises( + Exception, match="Unable to estimate performance for the given optimized model" + ): + estimate_performance(optimized_model_path, device.compiler_options) + + +def test_generate_supported_operators_report(tmp_path: Path) -> None: + """Test generating supported operators report.""" + with working_directory(tmp_path): + generate_supported_operators_report() + + md_file = tmp_path / "SUPPORTED_OPS.md" + assert md_file.is_file() + assert md_file.stat().st_size > 0 + + +def test_read_invalid_model(test_tflite_invalid_model: Path) -> None: + """Test that reading invalid model should fail with exception.""" + with pytest.raises( + Exception, match=f"Unable to read model {test_tflite_invalid_model}" + ): + device = EthosUConfiguration("ethos-u55-256") + estimate_performance(test_tflite_invalid_model, device.compiler_options) + + +def test_compile_invalid_model( + test_tflite_model: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that if model could not be compiled then correct exception raised.""" + mock_compiler = MagicMock() + mock_compiler.side_effect = Exception("Bad model!") + + monkeypatch.setattr("mlia.tools.vela_wrapper.compiler_driver", mock_compiler) + + model_path = tmp_path / "optimized_model.tflite" + with pytest.raises( + Exception, match="Model could not be optimized with Vela compiler" + ): + device = EthosUConfiguration("ethos-u55-256") + optimize_model(test_tflite_model, device.compiler_options, model_path) + + assert not model_path.exists() diff --git a/tests/mlia/test_utils_console.py b/tests/mlia/test_utils_console.py new file mode 100644 index 0000000..36975f8 --- /dev/null +++ b/tests/mlia/test_utils_console.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for console utility functions.""" +from typing import Iterable +from typing import List +from typing import Optional + +import pytest + +from mlia.utils.console import apply_style +from mlia.utils.console import create_section_header +from mlia.utils.console import produce_table +from mlia.utils.console import remove_ascii_codes + + +@pytest.mark.parametrize( + "rows, headers, table_style, expected_result", + [ + [[], [], "no_borders", ""], + [ + [["1", "2", "3"]], + ["Col 1", "Col 2", "Col 3"], + "default", + """ +┌───────┬───────┬───────┐ +│ Col 1 │ Col 2 │ Col 3 │ +╞═══════╪═══════╪═══════╡ +│ 1 │ 2 │ 3 │ +└───────┴───────┴───────┘ +""".strip(), + ], + [ + [["1", "2", "3"]], + ["Col 1", "Col 2", "Col 3"], + "nested", + "Col 1 Col 2 Col 3 \n \n1 2 3", + ], + [ + [["1", "2", "3"]], + ["Col 1", "Col 2", "Col 3"], + "no_borders", + " Col 1 Col 2 Col 3 \n 1 2 3", + ], + ], +) +def test_produce_table( + rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str +) -> None: + """Test produce_table function.""" + result = produce_table(rows, headers, table_style) + assert remove_ascii_codes(result) == expected_result + + +def test_produce_table_unknown_style() -> None: + """Test that function should fail if unknown style provided.""" + with pytest.raises(Exception, match="Unsupported table style unknown_style"): + produce_table([["1", "2", "3"]], [], "unknown_style") + + +@pytest.mark.parametrize( + "value, expected_result", + [ + ["some text", "some text"], + ["\033[32msome text\033[0m", "some text"], + ], +) +def test_remove_ascii_codes(value: str, expected_result: str) -> None: + """Test remove_ascii_codes function.""" + assert remove_ascii_codes(value) == expected_result + + +def test_apply_style() -> None: + """Test function apply_style.""" + assert apply_style("some text", "green") == "[green]some text" + + +@pytest.mark.parametrize( + "section_header, expected_result", + [ + [ + "Section header", + "\n--- Section header -------------------------------" + "------------------------------\n", + ], + [ + "", + f"\n{'-' * 80}\n", + ], + ], +) +def test_create_section_header(section_header: str, expected_result: str) -> None: + """Test function test_create_section.""" + assert create_section_header(section_header) == expected_result + + +def test_create_section_header_too_long_value() -> None: + """Test that header could not be created for the too long section names.""" + section_name = "section name" * 100 + with pytest.raises(ValueError, match="Section name too long"): + create_section_header(section_name) diff --git a/tests/mlia/test_utils_download.py b/tests/mlia/test_utils_download.py new file mode 100644 index 0000000..4f8e2dc --- /dev/null +++ b/tests/mlia/test_utils_download.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for download functionality.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Iterable +from typing import Optional +from unittest.mock import MagicMock +from unittest.mock import PropertyMock + +import pytest +import requests + +from mlia.utils.download import download +from mlia.utils.download import DownloadArtifact + + +def response_mock( + content_length: Optional[str], content_chunks: Iterable[bytes] +) -> MagicMock: + """Mock response object.""" + mock = MagicMock(spec=requests.Response) + mock.__enter__.return_value = mock + + type(mock).headers = PropertyMock(return_value={"Content-Length": content_length}) + mock.iter_content.return_value = content_chunks + + return mock + + +@pytest.mark.parametrize("show_progress", [True, False]) +@pytest.mark.parametrize( + "content_length, content_chunks, label", + [ + [ + "5", + [bytes(range(5))], + "Downloading artifact", + ], + [ + "10", + [bytes(range(5)), bytes(range(5))], + None, + ], + [ + None, + [bytes(range(5))], + "Downlading no size", + ], + [ + "abc", + [bytes(range(5))], + "Downloading wrong size", + ], + ], +) +def test_download( + show_progress: bool, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + content_length: Optional[str], + content_chunks: Iterable[bytes], + label: Optional[str], +) -> None: + """Test function download.""" + monkeypatch.setattr( + "mlia.utils.download.requests.get", + MagicMock(return_value=response_mock(content_length, content_chunks)), + ) + + dest = tmp_path / "sample.bin" + download("some_url", dest, show_progress=show_progress, label=label) + + assert dest.is_file() + assert dest.read_bytes() == bytes( + byte for chunk in content_chunks for byte in chunk + ) + + +@pytest.mark.parametrize( + "content_length, content_chunks, sha256_hash, expected_error", + [ + [ + "10", + [bytes(range(10))], + "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3", + does_not_raise(), + ], + [ + "10", + [bytes(range(10))], + "bad_hash", + pytest.raises(ValueError, match="Digests do not match"), + ], + ], +) +def test_download_artifact_download_to( + monkeypatch: pytest.MonkeyPatch, + content_length: Optional[str], + content_chunks: Iterable[bytes], + sha256_hash: str, + expected_error: Any, + tmp_path: Path, +) -> None: + """Test artifact downloading.""" + monkeypatch.setattr( + "mlia.utils.download.requests.get", + MagicMock(return_value=response_mock(content_length, content_chunks)), + ) + + with expected_error: + artifact = DownloadArtifact( + "test_artifact", + "some_url", + "artifact_filename", + "1.0", + sha256_hash, + ) + + dest = artifact.download_to(tmp_path) + assert isinstance(dest, Path) + assert dest.name == "artifact_filename" + + +def test_download_artifact_unable_to_overwrite( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that download process cannot overwrite file.""" + monkeypatch.setattr( + "mlia.utils.download.requests.get", + MagicMock(return_value=response_mock("10", [bytes(range(10))])), + ) + + artifact = DownloadArtifact( + "test_artifact", + "some_url", + "artifact_filename", + "1.0", + "sha256_hash", + ) + + existing_file = tmp_path / "artifact_filename" + existing_file.touch() + + with pytest.raises(ValueError, match=f"{existing_file} already exists"): + artifact.download_to(tmp_path) diff --git a/tests/mlia/test_utils_filesystem.py b/tests/mlia/test_utils_filesystem.py new file mode 100644 index 0000000..4d8d955 --- /dev/null +++ b/tests/mlia/test_utils_filesystem.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the filesystem module.""" +import contextlib +import json +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.utils.filesystem import all_files_exist +from mlia.utils.filesystem import all_paths_valid +from mlia.utils.filesystem import copy_all +from mlia.utils.filesystem import get_mlia_resources +from mlia.utils.filesystem import get_profile +from mlia.utils.filesystem import get_profiles_data +from mlia.utils.filesystem import get_profiles_file +from mlia.utils.filesystem import get_supported_profile_names +from mlia.utils.filesystem import get_vela_config +from mlia.utils.filesystem import sha256 +from mlia.utils.filesystem import temp_directory +from mlia.utils.filesystem import temp_file + + +def test_get_mlia_resources() -> None: + """Test resources getter.""" + assert get_mlia_resources().is_dir() + + +def test_get_vela_config() -> None: + """Test Vela config files getter.""" + assert get_vela_config().is_file() + assert get_vela_config().name == "vela.ini" + + +def test_profiles_file() -> None: + """Test profiles file getter.""" + assert get_profiles_file().is_file() + assert get_profiles_file().name == "profiles.json" + + +def test_profiles_data() -> None: + """Test profiles data getter.""" + assert list(get_profiles_data().keys()) == [ + "ethos-u55-256", + "ethos-u55-128", + "ethos-u65-512", + ] + + +def test_profiles_data_wrong_format( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test if profile data has wrong format.""" + wrong_profile_data = tmp_path / "bad.json" + with open(wrong_profile_data, "w", encoding="utf-8") as file: + json.dump([], file) + + monkeypatch.setattr( + "mlia.utils.filesystem.get_profiles_file", + MagicMock(return_value=wrong_profile_data), + ) + + with pytest.raises(Exception, match="Profiles data format is not valid"): + get_profiles_data() + + +def test_get_supported_profile_names() -> None: + """Test profile names getter.""" + assert list(get_supported_profile_names()) == [ + "ethos-u55-256", + "ethos-u55-128", + "ethos-u65-512", + ] + + +def test_get_profile() -> None: + """Test getting profile data.""" + assert get_profile("ethos-u55-256") == { + "target": "ethos-u55", + "mac": 256, + "system_config": "Ethos_U55_High_End_Embedded", + "memory_mode": "Shared_Sram", + } + + with pytest.raises(Exception, match="Unable to find target profile unknown"): + get_profile("unknown") + + +@pytest.mark.parametrize("raise_exception", [True, False]) +def test_temp_file(raise_exception: bool) -> None: + """Test temp_file context manager.""" + with contextlib.suppress(Exception): + with temp_file() as tmp_path: + assert tmp_path.is_file() + + if raise_exception: + raise Exception("Error!") + + assert not tmp_path.exists() + + +def test_sha256(tmp_path: Path) -> None: + """Test getting sha256 hash.""" + sample = tmp_path / "sample.txt" + + with open(sample, "w", encoding="utf-8") as file: + file.write("123") + + expected_hash = "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3" + assert sha256(sample) == expected_hash + + +def test_temp_dir_context_manager() -> None: + """Test context manager for temporary directories.""" + with temp_directory() as tmpdir: + assert isinstance(tmpdir, Path) + assert tmpdir.is_dir() + + assert not tmpdir.exists() + + +def test_all_files_exist(tmp_path: Path) -> None: + """Test function all_files_exist.""" + sample1 = tmp_path / "sample1.txt" + sample1.touch() + + sample2 = tmp_path / "sample2.txt" + sample2.touch() + + sample3 = tmp_path / "sample3.txt" + + assert all_files_exist([sample1, sample2]) is True + assert all_files_exist([sample1, sample2, sample3]) is False + + +def test_all_paths_valid(tmp_path: Path) -> None: + """Test function all_paths_valid.""" + sample = tmp_path / "sample.txt" + sample.touch() + + sample_dir = tmp_path / "sample_dir" + sample_dir.mkdir() + + unknown = tmp_path / "unknown.txt" + + assert all_paths_valid([sample, sample_dir]) is True + assert all_paths_valid([sample, sample_dir, unknown]) is False + + +def test_copy_all(tmp_path: Path) -> None: + """Test function copy_all.""" + sample = tmp_path / "sample1.txt" + sample.touch() + + sample_dir = tmp_path / "sample_dir" + sample_dir.mkdir() + + sample_nested_file = sample_dir / "sample_nested.txt" + sample_nested_file.touch() + + dest_dir = tmp_path / "dest" + copy_all(sample, sample_dir, dest=dest_dir) + + assert (dest_dir / sample.name).is_file() + assert (dest_dir / sample_nested_file.name).is_file() diff --git a/tests/mlia/test_utils_logging.py b/tests/mlia/test_utils_logging.py new file mode 100644 index 0000000..75ebceb --- /dev/null +++ b/tests/mlia/test_utils_logging.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the logging utility functions.""" +import logging +import sys +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Optional + +import pytest + +from mlia.cli.logging import create_log_handler + + +@pytest.mark.parametrize( + "file_path, stream, log_filter, delay, expected_error, expected_class", + [ + ( + "test.log", + None, + None, + True, + does_not_raise(), + logging.FileHandler, + ), + ( + None, + sys.stdout, + None, + None, + does_not_raise(), + logging.StreamHandler, + ), + ( + None, + None, + None, + None, + pytest.raises(Exception, match="Unable to create logging handler"), + None, + ), + ], +) +def test_create_log_handler( + file_path: Optional[Path], + stream: Optional[Any], + log_filter: Optional[logging.Filter], + delay: bool, + expected_error: Any, + expected_class: type, +) -> None: + """Test function test_create_log_handler.""" + with expected_error: + handler = create_log_handler( + file_path=file_path, + stream=stream, + log_level=logging.INFO, + log_format="%(name)s - %(message)s", + log_filter=log_filter, + delay=delay, + ) + assert isinstance(handler, expected_class) diff --git a/tests/mlia/test_utils_misc.py b/tests/mlia/test_utils_misc.py new file mode 100644 index 0000000..011d09e --- /dev/null +++ b/tests/mlia/test_utils_misc.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for misc util functions.""" +from unittest.mock import MagicMock + +import pytest + +from mlia.utils.misc import yes + + +@pytest.mark.parametrize( + "response, expected_result", + [ + ["Y", True], + ["y", True], + ["N", False], + ["n", False], + ], +) +def test_yes( + monkeypatch: pytest.MonkeyPatch, expected_result: bool, response: str +) -> None: + """Test yes function.""" + monkeypatch.setattr("builtins.input", MagicMock(return_value=response)) + assert yes("some_prompt") == expected_result diff --git a/tests/mlia/test_utils_proc.py b/tests/mlia/test_utils_proc.py new file mode 100644 index 0000000..8316ca5 --- /dev/null +++ b/tests/mlia/test_utils_proc.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the module utils/proc.""" +import signal +import subprocess +import time +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.utils.proc import CommandExecutor +from mlia.utils.proc import working_directory + + +class TestCommandExecutor: + """Tests for class CommandExecutor.""" + + @staticmethod + def test_execute() -> None: + """Test command execution.""" + executor = CommandExecutor() + + retcode, stdout, stderr = executor.execute(["echo", "hello world!"]) + assert retcode == 0 + assert stdout.decode().strip() == "hello world!" + assert stderr.decode() == "" + + @staticmethod + def test_submit() -> None: + """Test command submittion.""" + executor = CommandExecutor() + + running_command = executor.submit(["sleep", "10"]) + assert running_command.is_alive() is True + assert running_command.exit_code() is None + + running_command.kill() + for _ in range(3): + time.sleep(0.5) + if not running_command.is_alive(): + break + + assert running_command.is_alive() is False + assert running_command.exit_code() == -9 + + with pytest.raises(subprocess.CalledProcessError): + executor.execute(["sleep", "-1"]) + + @staticmethod + @pytest.mark.parametrize("wait", [True, False]) + def test_stop(wait: bool) -> None: + """Test command termination.""" + executor = CommandExecutor() + + running_command = executor.submit(["sleep", "10"]) + running_command.stop(wait=wait) + + if wait: + assert running_command.is_alive() is False + + @staticmethod + def test_unable_to_stop(monkeypatch: pytest.MonkeyPatch) -> None: + """Test when command could not be stopped.""" + running_command_mock = MagicMock() + running_command_mock.poll.return_value = None + + monkeypatch.setattr( + "mlia.utils.proc.subprocess.Popen", + MagicMock(return_value=running_command_mock), + ) + + with pytest.raises(Exception, match="Unable to stop running command"): + executor = CommandExecutor() + running_command = executor.submit(["sleep", "10"]) + + running_command.stop(num_of_attempts=1, interval=0.1) + + running_command_mock.send_signal.assert_called_once_with(signal.SIGINT) + + @staticmethod + def test_stop_after_several_attempts(monkeypatch: pytest.MonkeyPatch) -> None: + """Test when command could be stopped after several attempts.""" + running_command_mock = MagicMock() + running_command_mock.poll.side_effect = [None, 0] + + monkeypatch.setattr( + "mlia.utils.proc.subprocess.Popen", + MagicMock(return_value=running_command_mock), + ) + + executor = CommandExecutor() + running_command = executor.submit(["sleep", "10"]) + + running_command.stop(num_of_attempts=1, interval=0.1) + running_command_mock.send_signal.assert_called_once_with(signal.SIGINT) + + @staticmethod + def test_send_signal() -> None: + """Test sending signal.""" + executor = CommandExecutor() + running_command = executor.submit(["sleep", "10"]) + running_command.send_signal(signal.SIGINT) + + # wait a bit for a signal processing + time.sleep(1) + + assert running_command.is_alive() is False + assert running_command.exit_code() == -2 + + @staticmethod + @pytest.mark.parametrize( + "redirect_output, expected_output", [[True, "hello\n"], [False, ""]] + ) + def test_wait( + capsys: pytest.CaptureFixture, redirect_output: bool, expected_output: str + ) -> None: + """Test wait completion functionality.""" + executor = CommandExecutor() + + running_command = executor.submit(["echo", "hello"]) + running_command.wait(redirect_output=redirect_output) + + out, _ = capsys.readouterr() + assert out == expected_output + + +@pytest.mark.parametrize( + "should_exist, create_dir", + [ + [True, False], + [False, True], + ], +) +def test_working_directory_context_manager( + tmp_path: Path, should_exist: bool, create_dir: bool +) -> None: + """Test working_directory context manager.""" + prev_wd = Path.cwd() + + working_dir = tmp_path / "work_dir" + if should_exist: + working_dir.mkdir() + + with working_directory(working_dir, create_dir=create_dir) as current_working_dir: + assert current_working_dir.is_dir() + assert Path.cwd() == current_working_dir + + assert Path.cwd() == prev_wd diff --git a/tests/mlia/test_utils_types.py b/tests/mlia/test_utils_types.py new file mode 100644 index 0000000..4909efe --- /dev/null +++ b/tests/mlia/test_utils_types.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the types related utility functions.""" +from typing import Any +from typing import Iterable +from typing import Optional + +import pytest + +from mlia.utils.types import is_list_of +from mlia.utils.types import is_number +from mlia.utils.types import only_one_selected +from mlia.utils.types import parse_int + + +@pytest.mark.parametrize( + "value, expected_result", + [ + ["", False], + ["abc", False], + ["123", True], + ["123.1", True], + ["-123", True], + ["-123.1", True], + ["0", True], + ["1.e10", True], + ], +) +def test_is_number(value: str, expected_result: bool) -> None: + """Test function is_number.""" + assert is_number(value) == expected_result + + +@pytest.mark.parametrize( + "data, cls, elem_num, expected_result", + [ + [(1, 2), int, 2, True], + [[1, 2], int, 2, True], + [[1, 2], int, 3, False], + [["1", "2", "3"], str, None, True], + [["1", "2", "3"], int, None, False], + ], +) +def test_is_list( + data: Any, cls: type, elem_num: Optional[int], expected_result: bool +) -> None: + """Test function is_list.""" + assert is_list_of(data, cls, elem_num) == expected_result + + +@pytest.mark.parametrize( + "options, expected_result", + [ + [[True], True], + [[False], False], + [[True, True, False, False], False], + ], +) +def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> None: + """Test function only_one_selected.""" + assert only_one_selected(*options) == expected_result + + +@pytest.mark.parametrize( + "value, default, expected_int", + [ + ["1", None, 1], + ["abc", 123, 123], + [None, None, None], + [None, 11, 11], + ], +) +def test_parse_int( + value: Any, default: Optional[int], expected_int: Optional[int] +) -> None: + """Test function parse_int.""" + assert parse_int(value, default) == expected_int diff --git a/tests/mlia/utils/__init__.py b/tests/mlia/utils/__init__.py new file mode 100644 index 0000000..27166ef --- /dev/null +++ b/tests/mlia/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Test utils module.""" diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py new file mode 100644 index 0000000..4313cde --- /dev/null +++ b/tests/mlia/utils/common.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common test utils module.""" +from typing import Tuple + +import numpy as np +import tensorflow as tf + + +def get_dataset() -> Tuple[np.array, np.array]: + """Return sample dataset.""" + mnist = tf.keras.datasets.mnist + (x_train, y_train), _ = mnist.load_data() + x_train = x_train / 255.0 + + # Use subset of 60000 examples to keep unit test speed fast. + x_train = x_train[0:1] + y_train = y_train[0:1] + + return x_train, y_train + + +def train_model(model: tf.keras.Model) -> None: + """Train model using sample dataset.""" + num_epochs = 1 + + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) + + x_train, y_train = get_dataset() + + model.fit(x_train, y_train, epochs=num_epochs) diff --git a/tests/mlia/utils/logging.py b/tests/mlia/utils/logging.py new file mode 100644 index 0000000..d223fb2 --- /dev/null +++ b/tests/mlia/utils/logging.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Utils for logging.""" +import logging + + +def clear_loggers() -> None: + """Close the log handlers.""" + for _, logger in logging.Logger.manager.loggerDict.items(): + if not isinstance(logger, logging.PlaceHolder): + for handler in logger.handlers: + handler.close() + logger.removeHandler(handler) -- cgit v1.2.1