diff options
-rw-r--r-- | pyproject.toml | 5 | ||||
-rw-r--r-- | setup.cfg | 3 | ||||
-rw-r--r-- | src/mlia/backend/application.py | 23 | ||||
-rw-r--r-- | src/mlia/backend/common.py | 35 | ||||
-rw-r--r-- | src/mlia/backend/config.py | 26 | ||||
-rw-r--r-- | src/mlia/backend/controller.py | 134 | ||||
-rw-r--r-- | src/mlia/backend/execution.py | 524 | ||||
-rw-r--r-- | src/mlia/backend/fs.py | 36 | ||||
-rw-r--r-- | src/mlia/backend/manager.py | 177 | ||||
-rw-r--r-- | src/mlia/backend/output_consumer.py | 66 | ||||
-rw-r--r-- | src/mlia/backend/output_parser.py | 176 | ||||
-rw-r--r-- | src/mlia/backend/proc.py | 89 | ||||
-rw-r--r-- | src/mlia/backend/protocol.py | 325 | ||||
-rw-r--r-- | src/mlia/backend/source.py | 18 | ||||
-rw-r--r-- | src/mlia/backend/system.py | 229 | ||||
-rw-r--r-- | src/mlia/core/reporting.py | 2 | ||||
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_metrics.py | 33 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/SYSTEMS.txt (renamed from src/mlia/resources/aiet/systems/SYSTEMS.txt) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json (renamed from src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json) | 8 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json (renamed from src/mlia/resources/aiet/systems/corstone-300/aiet-config.json) | 8 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json (renamed from src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json) | 4 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json (renamed from src/mlia/resources/aiet/systems/corstone-310/aiet-config.json) | 4 | ||||
-rw-r--r-- | src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/.gitignore | 6 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/APPLICATIONS.txt (renamed from src/mlia/resources/aiet/applications/APPLICATIONS.txt) | 2 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json) | 1 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf) | bin | 426496 -> 426496 bytes | |||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json) | 1 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license (renamed from src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf) | bin | 426544 -> 426544 bytes | |||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json) | 1 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json.license (renamed from src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf) | bin | 2524028 -> 2524028 bytes | |||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json) | 1 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license (renamed from src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf) | bin | 426488 -> 426488 bytes | |||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json) | 1 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license (renamed from src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf) | bin | 426536 -> 426536 bytes | |||
-rw-r--r-- | src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license (renamed from src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license) | 0 | ||||
-rw-r--r-- | src/mlia/tools/metadata/corstone.py | 28 | ||||
-rw-r--r-- | src/mlia/utils/filesystem.py | 18 | ||||
-rw-r--r-- | src/mlia/utils/proc.py | 152 | ||||
-rw-r--r-- | tests/conftest.py | 103 | ||||
-rw-r--r-- | tests/mlia/__init__.py | 3 | ||||
-rw-r--r-- | tests/mlia/conftest.py | 111 | ||||
-rw-r--r-- | tests/mlia/test_backend_controller.py | 160 | ||||
-rw-r--r-- | tests/mlia/test_backend_execution.py | 518 | ||||
-rw-r--r-- | tests/mlia/test_backend_output_parser.py | 152 | ||||
-rw-r--r-- | tests/mlia/test_backend_protocol.py | 231 | ||||
-rw-r--r-- | tests/mlia/test_utils_proc.py | 149 | ||||
-rw-r--r-- | tests/test_api.py (renamed from tests/mlia/test_api.py) | 0 | ||||
-rw-r--r-- | tests/test_backend_application.py (renamed from tests/mlia/test_backend_application.py) | 48 | ||||
-rw-r--r-- | tests/test_backend_common.py (renamed from tests/mlia/test_backend_common.py) | 10 | ||||
-rw-r--r-- | tests/test_backend_execution.py | 203 | ||||
-rw-r--r-- | tests/test_backend_fs.py (renamed from tests/mlia/test_backend_fs.py) | 34 | ||||
-rw-r--r-- | tests/test_backend_manager.py (renamed from tests/mlia/test_backend_manager.py) | 190 | ||||
-rw-r--r-- | tests/test_backend_output_consumer.py | 99 | ||||
-rw-r--r-- | tests/test_backend_proc.py (renamed from tests/mlia/test_backend_proc.py) | 69 | ||||
-rw-r--r-- | tests/test_backend_source.py (renamed from tests/mlia/test_backend_source.py) | 4 | ||||
-rw-r--r-- | tests/test_backend_system.py (renamed from tests/mlia/test_backend_system.py) | 204 | ||||
-rw-r--r-- | tests/test_cli_commands.py (renamed from tests/mlia/test_cli_commands.py) | 0 | ||||
-rw-r--r-- | tests/test_cli_config.py (renamed from tests/mlia/test_cli_config.py) | 0 | ||||
-rw-r--r-- | tests/test_cli_helpers.py (renamed from tests/mlia/test_cli_helpers.py) | 0 | ||||
-rw-r--r-- | tests/test_cli_logging.py (renamed from tests/mlia/test_cli_logging.py) | 2 | ||||
-rw-r--r-- | tests/test_cli_main.py (renamed from tests/mlia/test_cli_main.py) | 2 | ||||
-rw-r--r-- | tests/test_cli_options.py (renamed from tests/mlia/test_cli_options.py) | 0 | ||||
-rw-r--r-- | tests/test_core_advice_generation.py (renamed from tests/mlia/test_core_advice_generation.py) | 0 | ||||
-rw-r--r-- | tests/test_core_advisor.py (renamed from tests/mlia/test_core_advisor.py) | 0 | ||||
-rw-r--r-- | tests/test_core_context.py (renamed from tests/mlia/test_core_context.py) | 0 | ||||
-rw-r--r-- | tests/test_core_data_analysis.py (renamed from tests/mlia/test_core_data_analysis.py) | 0 | ||||
-rw-r--r-- | tests/test_core_events.py (renamed from tests/mlia/test_core_events.py) | 0 | ||||
-rw-r--r-- | tests/test_core_helpers.py (renamed from tests/mlia/test_core_helpers.py) | 0 | ||||
-rw-r--r-- | tests/test_core_mixins.py (renamed from tests/mlia/test_core_mixins.py) | 0 | ||||
-rw-r--r-- | tests/test_core_performance.py (renamed from tests/mlia/test_core_performance.py) | 0 | ||||
-rw-r--r-- | tests/test_core_reporting.py (renamed from tests/mlia/test_core_reporting.py) | 0 | ||||
-rw-r--r-- | tests/test_core_workflow.py (renamed from tests/mlia/test_core_workflow.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_advice_generation.py (renamed from tests/mlia/test_devices_ethosu_advice_generation.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_advisor.py (renamed from tests/mlia/test_devices_ethosu_advisor.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_config.py (renamed from tests/mlia/test_devices_ethosu_config.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_data_analysis.py (renamed from tests/mlia/test_devices_ethosu_data_analysis.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_data_collection.py (renamed from tests/mlia/test_devices_ethosu_data_collection.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_performance.py (renamed from tests/mlia/test_devices_ethosu_performance.py) | 0 | ||||
-rw-r--r-- | tests/test_devices_ethosu_reporters.py (renamed from tests/mlia/test_devices_ethosu_reporters.py) | 0 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_config.py (renamed from tests/mlia/test_nn_tensorflow_config.py) | 0 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_optimizations_clustering.py (renamed from tests/mlia/test_nn_tensorflow_optimizations_clustering.py) | 4 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_optimizations_pruning.py (renamed from tests/mlia/test_nn_tensorflow_optimizations_pruning.py) | 4 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_optimizations_select.py (renamed from tests/mlia/test_nn_tensorflow_optimizations_select.py) | 0 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_tflite_metrics.py (renamed from tests/mlia/test_nn_tensorflow_tflite_metrics.py) | 16 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_utils.py (renamed from tests/mlia/test_nn_tensorflow_utils.py) | 0 | ||||
-rw-r--r-- | tests/test_resources/application_config.json (renamed from tests/mlia/test_resources/application_config.json) | 2 | ||||
-rw-r--r-- | tests/test_resources/application_config.json.license (renamed from tests/mlia/test_resources/application_config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application1/backend-config.json (renamed from tests/mlia/test_resources/backends/applications/application1/aiet-config.json) | 1 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application1/backend-config.json.license (renamed from tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application2/backend-config.json (renamed from tests/mlia/test_resources/backends/applications/application2/aiet-config.json) | 1 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application2/backend-config.json.license (renamed from tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application3/readme.txt (renamed from tests/mlia/test_resources/backends/applications/application3/readme.txt) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application4/backend-config.json (renamed from tests/mlia/test_resources/backends/applications/application4/aiet-config.json) | 17 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application4/backend-config.json.license (renamed from tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application4/hello_app.txt (renamed from tests/mlia/test_resources/backends/applications/application4/hello_app.txt) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application5/backend-config.json (renamed from tests/mlia/test_resources/backends/applications/application5/aiet-config.json) | 28 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application5/backend-config.json.license (renamed from tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application6/backend-config.json (renamed from tests/mlia/test_resources/backends/applications/application6/aiet-config.json) | 1 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/application6/backend-config.json.license (renamed from tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/applications/readme.txt (renamed from tests/mlia/test_resources/backends/applications/readme.txt) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system1/backend-config.json (renamed from tests/mlia/test_resources/backends/systems/system1/aiet-config.json) | 11 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system1/backend-config.json.license (renamed from tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system1/system_artifact/dummy.txt (renamed from tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system2/backend-config.json (renamed from tests/mlia/test_resources/backends/systems/system2/aiet-config.json) | 8 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system2/backend-config.json.license (renamed from tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system3/readme.txt (renamed from tests/mlia/test_resources/backends/systems/system3/readme.txt) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system4/backend-config.json (renamed from tests/mlia/test_resources/backends/systems/system4/aiet-config.json) | 8 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system4/backend-config.json.license (renamed from tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system6/backend-config.json (renamed from tests/mlia/test_resources/backends/systems/system6/aiet-config.json) | 4 | ||||
-rw-r--r-- | tests/test_resources/backends/systems/system6/backend-config.json.license (renamed from tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/hello_world.json (renamed from tests/mlia/test_resources/hello_world.json) | 1 | ||||
-rw-r--r-- | tests/test_resources/hello_world.json.license (renamed from tests/mlia/test_resources/hello_world.json.license) | 0 | ||||
-rwxr-xr-x | tests/test_resources/scripts/test_backend_run (renamed from tests/mlia/test_resources/scripts/test_backend_run) | 0 | ||||
-rw-r--r-- | tests/test_resources/scripts/test_backend_run_script.sh (renamed from tests/mlia/test_resources/scripts/test_backend_run_script.sh) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_empty_config/backend-config.json (renamed from tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_empty_config/backend-config.json.license (renamed from tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_valid_config/backend-config.json (renamed from tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json) | 6 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_valid_config/backend-config.json.license (renamed from tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json.license (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json) | 6 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json.license (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json) | 6 | ||||
-rw-r--r-- | tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json.license (renamed from tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/systems/system_with_empty_config/backend-config.json (renamed from tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/systems/system_with_empty_config/backend-config.json.license (renamed from tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/various/systems/system_with_valid_config/backend-config.json (renamed from tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json) | 4 | ||||
-rw-r--r-- | tests/test_resources/various/systems/system_with_valid_config/backend-config.json.license (renamed from tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license) | 0 | ||||
-rw-r--r-- | tests/test_resources/vela/sample_vela.ini (renamed from tests/mlia/test_resources/vela/sample_vela.ini) | 0 | ||||
-rw-r--r-- | tests/test_tools_metadata_common.py (renamed from tests/mlia/test_tools_metadata_common.py) | 0 | ||||
-rw-r--r-- | tests/test_tools_metadata_corstone.py (renamed from tests/mlia/test_tools_metadata_corstone.py) | 0 | ||||
-rw-r--r-- | tests/test_tools_vela_wrapper.py (renamed from tests/mlia/test_tools_vela_wrapper.py) | 2 | ||||
-rw-r--r-- | tests/test_utils_console.py (renamed from tests/mlia/test_utils_console.py) | 0 | ||||
-rw-r--r-- | tests/test_utils_download.py (renamed from tests/mlia/test_utils_download.py) | 0 | ||||
-rw-r--r-- | tests/test_utils_filesystem.py (renamed from tests/mlia/test_utils_filesystem.py) | 25 | ||||
-rw-r--r-- | tests/test_utils_logging.py (renamed from tests/mlia/test_utils_logging.py) | 0 | ||||
-rw-r--r-- | tests/test_utils_misc.py (renamed from tests/mlia/test_utils_misc.py) | 0 | ||||
-rw-r--r-- | tests/test_utils_types.py (renamed from tests/mlia/test_utils_types.py) | 0 | ||||
-rw-r--r-- | tests/utils/__init__.py (renamed from tests/mlia/utils/__init__.py) | 0 | ||||
-rw-r--r-- | tests/utils/common.py (renamed from tests/mlia/utils/common.py) | 0 | ||||
-rw-r--r-- | tests/utils/logging.py (renamed from tests/mlia/utils/logging.py) | 0 |
154 files changed, 891 insertions, 3692 deletions
diff --git a/pyproject.toml b/pyproject.toml index 05363d8..1dcbf21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,7 @@ max-attributes=10 # Provide basic compatibility with black disable = [ - "wrong-import-order", - "consider-using-f-string" # C0209 + "wrong-import-order" ] enable = [ @@ -71,8 +70,6 @@ check_untyped_defs = true [[tool.mypy.overrides]] module = [ "pkg_resources", - "paramiko", "requests", - "filelock" ] ignore_missing_imports = true @@ -35,9 +35,6 @@ install_requires = requests rich sh - paramiko - filelock - psutil [options.packages.find] where = src diff --git a/src/mlia/backend/application.py b/src/mlia/backend/application.py index eb85212..4b04324 100644 --- a/src/mlia/backend/application.py +++ b/src/mlia/backend/application.py @@ -11,10 +11,9 @@ from typing import Optional from mlia.backend.common import Backend from mlia.backend.common import ConfigurationException -from mlia.backend.common import DataPaths from mlia.backend.common import get_backend_configs from mlia.backend.common import get_backend_directories -from mlia.backend.common import load_application_or_tool_configs +from mlia.backend.common import load_application_configs from mlia.backend.common import load_config from mlia.backend.common import remove_backend from mlia.backend.config import ApplicationConfig @@ -75,7 +74,7 @@ def install_application(source_path: Path) -> None: if already_installed: names = {application.name for application in already_installed} raise ConfigurationException( - "Applications [{}] are already installed".format(",".join(names)) + f"Applications [{','.join(names)}] are already installed." ) create_destination_and_install(source, get_backends_path("applications")) @@ -105,7 +104,6 @@ class Application(Backend): super().__init__(config) self.supported_systems = config.get("supported_systems", []) - self.deploy_data = config.get("deploy_data", []) def __eq__(self, other: object) -> bool: """Overload operator ==.""" @@ -122,21 +120,6 @@ class Application(Backend): """Check if the application can run on the system passed as argument.""" return system_name in self.supported_systems - def get_deploy_data(self) -> List[DataPaths]: - """Validate and return data specified in the config file.""" - if self.config_location is None: - raise ConfigurationException( - "Unable to get application {} config location".format(self.name) - ) - - deploy_data = [] - for item in self.deploy_data: - src, dst = item - src_full_path = self.config_location / src - assert src_full_path.exists(), "{} does not exists".format(src_full_path) - deploy_data.append(DataPaths(src_full_path, dst)) - return deploy_data - def get_details(self) -> Dict[str, Any]: """Return dictionary with information about the Application instance.""" output = { @@ -180,7 +163,7 @@ def load_applications(config: ExtendedApplicationConfig) -> List[Application]: supported systems. For each supported system this function will return separate Application instance with appropriate configuration. """ - configs = load_application_or_tool_configs(config, ApplicationConfig) + configs = load_application_configs(config, ApplicationConfig) applications = [Application(cfg) for cfg in configs] for application in applications: application.remove_unused_params() diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py index 2bbb9d3..e61d6b6 100644 --- a/src/mlia/backend/common.py +++ b/src/mlia/backend/common.py @@ -33,7 +33,7 @@ from mlia.backend.fs import remove_resource from mlia.backend.fs import ResourceType -BACKEND_CONFIG_FILE: Final[str] = "aiet-config.json" +BACKEND_CONFIG_FILE: Final[str] = "backend-config.json" class ConfigurationException(Exception): @@ -126,10 +126,6 @@ class Backend(ABC): self.description = config.get("description", "") self.config_location = config.get("config_location") self.variables = config.get("variables", {}) - self.build_dir = config.get("build_dir") - self.lock = config.get("lock", False) - if self.build_dir: - self.build_dir = self._substitute_variables(self.build_dir) self.annotations = config.get("annotations", {}) self._parse_commands_and_params(config) @@ -145,7 +141,7 @@ class Backend(ABC): command = self.commands.get(command_name) if not command: - raise AttributeError("Unknown command: '{}'".format(command_name)) + raise AttributeError(f"Unknown command: '{command_name}'") # Iterate over all available parameters until we have a match. for param in command.params: @@ -209,7 +205,7 @@ class Backend(ABC): def var_value(match: Match) -> str: var_name = match["var_name"] if var_name not in self.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) + raise ConfigurationException(f"Unknown variable {var_name}") return self.variables[var_name] @@ -312,7 +308,7 @@ class Backend(ABC): command = self.commands.get(command_name) if not command: raise ConfigurationException( - "Command '{}' could not be found.".format(command_name) + f"Command '{command_name}' could not be found." ) commands_to_run = [] @@ -394,7 +390,7 @@ class Command: if repeated_aliases: raise ConfigurationException( - "Non unique aliases {}".format(", ".join(repeated_aliases)) + f"Non-unique aliases {', '.join(repeated_aliases)}" ) both_name_and_alias = [ @@ -404,9 +400,8 @@ class Command: ] if both_name_and_alias: raise ConfigurationException( - "Aliases {} could not be used as parameter name".format( - ", ".join(both_name_and_alias) - ) + f"Aliases {', '.join(both_name_and_alias)} could not be used " + "as parameter name." ) def get_details(self) -> Dict: @@ -449,12 +444,12 @@ def resolve_all_parameters( return str_val -def load_application_or_tool_configs( +def load_application_configs( config: Any, config_type: Type[Any], is_system_required: bool = True, ) -> Any: - """Get one config for each system supported by the application/tool. + """Get one config for each system supported by the application. The configuration could contain different parameters/commands for different supported systems. For each supported system this function will return separate @@ -501,15 +496,13 @@ def load_application_or_tool_configs( if params_default and params_tool: if any(not p.get("alias") for p in params_default): raise ConfigurationException( - "Default parameters for command {} should have aliases".format( - command_name - ) + f"Default parameters for command {command_name} " + "should have aliases" ) if any(not p.get("alias") for p in params_tool): raise ConfigurationException( - "{} parameters for command {} should have aliases".format( - system_name, command_name - ) + f"{system_name} parameters for command {command_name} " + "should have aliases." ) merged_by_alias = { @@ -519,8 +512,6 @@ def load_application_or_tool_configs( params[command_name] = list(merged_by_alias.values()) merged_config["user_params"] = params - merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) - merged_config["lock"] = system.get("lock", config.get("lock", False)) merged_config["variables"] = { **config.get("variables", {}), **system.get("variables", {}), diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py index 657adef..9a56fa9 100644 --- a/src/mlia/backend/config.py +++ b/src/mlia/backend/config.py @@ -4,9 +4,7 @@ from pathlib import Path from typing import Dict from typing import List -from typing import Literal from typing import Optional -from typing import Tuple from typing import TypedDict from typing import Union @@ -29,9 +27,7 @@ class ExecutionConfig(TypedDict, total=False): commands: Dict[str, List[str]] user_params: UserParamsConfig - build_dir: str variables: Dict[str, str] - lock: bool class NamedExecutionConfig(ExecutionConfig): @@ -53,39 +49,17 @@ class ApplicationConfig(BaseBackendConfig, total=False): """Application configuration.""" supported_systems: List[str] - deploy_data: List[Tuple[str, str]] class ExtendedApplicationConfig(BaseBackendConfig, total=False): """Extended application configuration.""" supported_systems: List[NamedExecutionConfig] - deploy_data: List[Tuple[str, str]] - - -class ProtocolConfig(TypedDict, total=False): - """Protocol config.""" - - protocol: Literal["local", "ssh"] - - -class SSHConfig(ProtocolConfig, total=False): - """SSH configuration.""" - - username: str - password: str - hostname: str - port: str - - -class LocalProtocolConfig(ProtocolConfig, total=False): - """Local protocol config.""" class SystemConfig(BaseBackendConfig, total=False): """System configuration.""" - data_transfer: Union[SSHConfig, LocalProtocolConfig] reporting: Dict[str, Dict] diff --git a/src/mlia/backend/controller.py b/src/mlia/backend/controller.py deleted file mode 100644 index f1b68a9..0000000 --- a/src/mlia/backend/controller.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Controller backend module.""" -import time -from pathlib import Path -from typing import List -from typing import Optional -from typing import Tuple - -import psutil -import sh - -from mlia.backend.common import ConfigurationException -from mlia.backend.fs import read_file_as_string -from mlia.backend.proc import execute_command -from mlia.backend.proc import get_stdout_stderr_paths -from mlia.backend.proc import read_process_info -from mlia.backend.proc import save_process_info -from mlia.backend.proc import terminate_command -from mlia.backend.proc import terminate_external_process - - -class SystemController: - """System controller class.""" - - def __init__(self) -> None: - """Create new instance of service controller.""" - self.cmd: Optional[sh.RunningCommand] = None - self.out_path: Optional[Path] = None - self.err_path: Optional[Path] = None - - def before_start(self) -> None: - """Run actions before system start.""" - - def after_start(self) -> None: - """Run actions after system start.""" - - def start(self, commands: List[str], cwd: Path) -> None: - """Start system.""" - if not isinstance(cwd, Path) or not cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(cwd)) - - if len(commands) != 1: - raise ConfigurationException("System should have only one command to run") - - startup_command = commands[0] - if not startup_command: - raise ConfigurationException("No startup command provided") - - self.before_start() - - self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) - - self.cmd = execute_command( - startup_command, - cwd, - bg=True, - out=str(self.out_path), - err=str(self.err_path), - ) - - self.after_start() - - def stop( - self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 - ) -> None: - """Stop system.""" - if self.cmd is not None and self.is_running(): - terminate_command(self.cmd, wait, wait_period, number_of_attempts) - - def is_running(self) -> bool: - """Check if underlying process is running.""" - return self.cmd is not None and self.cmd.is_alive() - - def get_output(self) -> Tuple[str, str]: - """Return application output.""" - if self.cmd is None or self.out_path is None or self.err_path is None: - return ("", "") - - return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) - - -class SystemControllerSingleInstance(SystemController): - """System controller with support of system's single instance.""" - - def __init__(self, pid_file_path: Optional[Path] = None) -> None: - """Create new instance of the service controller.""" - super().__init__() - self.pid_file_path = pid_file_path - - def before_start(self) -> None: - """Run actions before system start.""" - self._check_if_previous_instance_is_running() - - def after_start(self) -> None: - """Run actions after system start.""" - self._save_process_info() - - def _check_if_previous_instance_is_running(self) -> None: - """Check if another instance of the system is running.""" - process_info = read_process_info(self._pid_file()) - - for item in process_info: - try: - process = psutil.Process(item.pid) - same_process = ( - process.name() == item.name - and process.exe() == item.executable - and process.cwd() == item.cwd - ) - if same_process: - print( - "Stopping previous instance of the system [{}]".format(item.pid) - ) - terminate_external_process(process) - except psutil.NoSuchProcess: - pass - - def _save_process_info(self, wait_period: float = 2) -> None: - """Save information about system's processes.""" - if self.cmd is None or not self.is_running(): - return - - # give some time for the system to start - time.sleep(wait_period) - - save_process_info(self.cmd.process.pid, self._pid_file()) - - def _pid_file(self) -> Path: - """Return path to file which is used for saving process info.""" - if not self.pid_file_path: - raise Exception("No pid file path presented") - - return self.pid_file_path diff --git a/src/mlia/backend/execution.py b/src/mlia/backend/execution.py index 749ccdb..5340a47 100644 --- a/src/mlia/backend/execution.py +++ b/src/mlia/backend/execution.py @@ -1,167 +1,49 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Application execution module.""" -import itertools -import json -import random +import logging import re -import string -import sys -import time -from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack -from pathlib import Path -from typing import Any -from typing import Callable from typing import cast -from typing import ContextManager -from typing import Dict -from typing import Generator from typing import List from typing import Optional -from typing import Sequence from typing import Tuple -from typing import TypedDict - -from filelock import FileLock -from filelock import Timeout from mlia.backend.application import Application from mlia.backend.application import get_application from mlia.backend.common import Backend from mlia.backend.common import ConfigurationException -from mlia.backend.common import DataPaths from mlia.backend.common import Param -from mlia.backend.common import parse_raw_parameter -from mlia.backend.common import resolve_all_parameters -from mlia.backend.fs import recreate_directory -from mlia.backend.fs import remove_directory -from mlia.backend.fs import valid_for_filename -from mlia.backend.output_parser import Base64OutputParser -from mlia.backend.output_parser import OutputParser -from mlia.backend.output_parser import RegexOutputParser -from mlia.backend.proc import run_and_wait -from mlia.backend.system import ControlledSystem from mlia.backend.system import get_system -from mlia.backend.system import StandaloneSystem from mlia.backend.system import System +logger = logging.getLogger(__name__) + class AnotherInstanceIsRunningException(Exception): """Concurrent execution error.""" -class ConnectionException(Exception): - """Connection exception.""" - - -class ExecutionParams(TypedDict, total=False): - """Execution parameters.""" - - disable_locking: bool - unique_build_dir: bool - - -class ExecutionContext: +class ExecutionContext: # pylint: disable=too-few-public-methods """Command execution context.""" - # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__( self, app: Application, app_params: List[str], - system: Optional[System], + system: System, system_params: List[str], - custom_deploy_data: Optional[List[DataPaths]] = None, - execution_params: Optional[ExecutionParams] = None, - report_file: Optional[Path] = None, ): """Init execution context.""" self.app = app self.app_params = app_params - self.custom_deploy_data = custom_deploy_data or [] self.system = system self.system_params = system_params - self.execution_params = execution_params or ExecutionParams() - self.report_file = report_file - - self.reporter: Optional[Reporter] - if self.report_file: - # Create reporter with output parsers - parsers: List[OutputParser] = [] - if system and system.reporting: - # Add RegexOutputParser, if it is configured in the system - parsers.append(RegexOutputParser("system", system.reporting["regex"])) - # Add Base64 parser for applications - parsers.append(Base64OutputParser("application")) - self.reporter = Reporter(parsers=parsers) - else: - self.reporter = None # No reporter needed. self.param_resolver = ParamResolver(self) - self._resolved_build_dir: Optional[Path] = None self.stdout: Optional[bytearray] = None self.stderr: Optional[bytearray] = None - @property - def is_deploy_needed(self) -> bool: - """Check if application requires data deployment.""" - return len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0 - - @property - def is_locking_required(self) -> bool: - """Return true if any form of locking required.""" - return not self._disable_locking() and ( - self.app.lock or (self.system is not None and self.system.lock) - ) - - @property - def is_build_required(self) -> bool: - """Return true if application build required.""" - return "build" in self.app.commands - - @property - def is_unique_build_dir_required(self) -> bool: - """Return true if unique build dir required.""" - return self.execution_params.get("unique_build_dir", False) - - def build_dir(self) -> Path: - """Return resolved application build dir.""" - if self._resolved_build_dir is not None: - return self._resolved_build_dir - - if ( - not isinstance(self.app.config_location, Path) - or not self.app.config_location.is_dir() - ): - raise ConfigurationException( - "Application {} has wrong config location".format(self.app.name) - ) - - _build_dir = self.app.build_dir - if _build_dir: - _build_dir = resolve_all_parameters(_build_dir, self.param_resolver) - - if not _build_dir: - raise ConfigurationException( - "No build directory defined for the app {}".format(self.app.name) - ) - - if self.is_unique_build_dir_required: - random_suffix = "".join( - random.choices(string.ascii_lowercase + string.digits, k=7) - ) - _build_dir = "{}_{}".format(_build_dir, random_suffix) - - self._resolved_build_dir = self.app.config_location / _build_dir - return self._resolved_build_dir - - def _disable_locking(self) -> bool: - """Return true if locking should be disabled.""" - return self.execution_params.get("disable_locking", False) - class ParamResolver: """Parameter resolver.""" @@ -187,7 +69,7 @@ class ParamResolver: i = int(index_or_alias) if i not in range(len(resolved_params)): raise ConfigurationException( - "Invalid index {} for user params of command {}".format(i, cmd_name) + f"Invalid index {i} for user params of command {cmd_name}" ) param_value, param = resolved_params[i] else: @@ -198,9 +80,8 @@ class ParamResolver: if param is None: raise ConfigurationException( - "No user parameter for command '{}' with alias '{}'.".format( - cmd_name, index_or_alias - ) + f"No user parameter for command '{cmd_name}' with " + f"alias '{index_or_alias}'." ) if param_value: @@ -220,17 +101,12 @@ class ParamResolver: param_name = param.name separator = "" if param.name.endswith("=") else " " - return "{param_name}{separator}{param_value}".format( - param_name=param_name, - separator=separator, - param_value=param_value, - ) + return f"{param_name}{separator}{param_value}" if param.name is None: raise ConfigurationException( - "Missing user parameter with alias '{}' for command '{}'.".format( - index_or_alias, cmd_name - ) + f"Missing user parameter with alias '{index_or_alias}' for " + f"command '{cmd_name}'." ) return param.name # flag: just return the parameter name @@ -242,12 +118,12 @@ class ParamResolver: if backend_type == "system": backend = cast(Backend, self.ctx.system) backend_params = self.ctx.system_params - else: # Application or Tool backend + else: # Application backend backend = cast(Backend, self.ctx.app) backend_params = self.ctx.app_params if cmd_name not in backend.commands: - raise ConfigurationException("Command {} not found".format(cmd_name)) + raise ConfigurationException(f"Command {cmd_name} not found") if return_params: params = backend.resolved_parameters(cmd_name, backend_params) @@ -255,7 +131,7 @@ class ParamResolver: i = int(index_or_alias) if i not in range(len(params)): raise ConfigurationException( - "Invalid parameter index {} for command {}".format(i, cmd_name) + f"Invalid parameter index {i} for command {cmd_name}" ) param_value = params[i][0] @@ -269,20 +145,19 @@ class ParamResolver: if not param_value: raise ConfigurationException( ( - "No value for parameter with index or alias {} of command {}" - ).format(index_or_alias, cmd_name) + "No value for parameter with index or " + f"alias {index_or_alias} of command {cmd_name}." + ) ) return param_value if not index_or_alias.isnumeric(): - raise ConfigurationException("Bad command index {}".format(index_or_alias)) + raise ConfigurationException(f"Bad command index {index_or_alias}") i = int(index_or_alias) commands = backend.build_command(cmd_name, backend_params, self.param_resolver) if i not in range(len(commands)): - raise ConfigurationException( - "Invalid index {} for command {}".format(i, cmd_name) - ) + raise ConfigurationException(f"Invalid index {i} for command {cmd_name}") return commands[i] @@ -290,11 +165,11 @@ class ParamResolver: """Resolve variable value.""" if backend_type == "system": backend = cast(Backend, self.ctx.system) - else: # Application or Tool backend + else: # Application backend backend = cast(Backend, self.ctx.app) if var_name not in backend.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) + raise ConfigurationException(f"Unknown variable {var_name}") return backend.variables[var_name] @@ -309,7 +184,7 @@ class ParamResolver: # "system.commands.run.params:0" # Note: 'software' is included for backward compatibility. commands_and_params_match = re.match( - r"(?P<type>application|software|tool|system)[.]commands[.]" + r"(?P<type>application|software|system)[.]commands[.]" r"(?P<name>\w+)" r"(?P<params>[.]params|)[:]" r"(?P<index_or_alias>\w+)", @@ -329,7 +204,7 @@ class ParamResolver: # Note: 'software' is included for backward compatibility. variables_match = re.match( - r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)", + r"(?P<type>application|software|system)[.]variables:(?P<var_name>\w+)", param_name, ) if variables_match: @@ -344,9 +219,7 @@ class ParamResolver: index_or_alias = user_params_match["index_or_alias"] return self.resolve_user_params(cmd_name, index_or_alias, resolved_params) - raise ConfigurationException( - "Unable to resolve parameter {}".format(param_name) - ) + raise ConfigurationException(f"Unable to resolve parameter {param_name}") def param_resolver( self, @@ -357,24 +230,14 @@ class ParamResolver: """Resolve parameter value based on current execution context.""" # Note: 'software.*' is included for backward compatibility. resolved_param = None - if param_name in ["application.name", "tool.name", "software.name"]: + if param_name in ["application.name", "software.name"]: resolved_param = self.ctx.app.name - elif param_name in [ - "application.description", - "tool.description", - "software.description", - ]: + elif param_name in ["application.description", "software.description"]: resolved_param = self.ctx.app.description elif self.ctx.app.config_location and ( - param_name - in ["application.config_dir", "tool.config_dir", "software.config_dir"] + param_name in ["application.config_dir", "software.config_dir"] ): resolved_param = str(self.ctx.app.config_location.absolute()) - elif self.ctx.app.build_dir and ( - param_name - in ["application.build_dir", "tool.build_dir", "software.build_dir"] - ): - resolved_param = str(self.ctx.build_dir().absolute()) elif self.ctx.system is not None: if param_name == "system.name": resolved_param = self.ctx.system.name @@ -397,82 +260,6 @@ class ParamResolver: return self.param_resolver(param_name, cmd_name, resolved_params) -class Reporter: - """Report metrics from the simulation output.""" - - def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None: - """Create an empty reporter (i.e. no parsers registered).""" - self.parsers: List[OutputParser] = parsers if parsers is not None else [] - self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) - - def parse(self, output: bytearray) -> None: - """Parse output and append parsed metrics to internal report dict.""" - for parser in self.parsers: - # Merge metrics from different parsers (do not overwrite) - self._report[parser.name]["metrics"].update(parser(output)) - - def get_filtered_output(self, output: bytearray) -> bytearray: - """Filter the output according to each parser.""" - for parser in self.parsers: - output = parser.filter_out_parsed_content(output) - return output - - def report(self, ctx: ExecutionContext) -> Dict[str, Any]: - """Add static simulation info to parsed data and return the report.""" - report: Dict[str, Any] = defaultdict(dict) - # Add static simulation info - report.update(self._static_info(ctx)) - # Add metrics parsed from the output - for key, val in self._report.items(): - report[key].update(val) - return report - - @staticmethod - def save(report: Dict[str, Any], report_file: Path) -> None: - """Save the report to a JSON file.""" - with open(report_file, "w", encoding="utf-8") as file: - json.dump(report, file, indent=4) - - @staticmethod - def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: - """ - Build a dict of all parameters, {name:value}. - - Param values taken from command line if specified, defaults otherwise. - """ - # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} - app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) - - # a map of params declared in the application, with values taken from the CLI, - # defaults otherwise - all_params = { - (p.alias or p.name): app_params_map.get( - cast(str, p.name), cast(str, p.default_value) - ) - for cmd in backend.commands.values() - for p in cmd.params - } - return cast(Dict[str, str], all_params) - - @staticmethod - def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: - """Extract static simulation information from the context.""" - if ctx.system is None: - raise ValueError("No system available to report.") - - info = { - "system": { - "name": ctx.system.name, - "params": Reporter._compute_all_params(ctx.system_params, ctx.system), - }, - "application": { - "name": ctx.app.name, - "params": Reporter._compute_all_params(ctx.app_params, ctx.app), - }, - } - return info - - def validate_parameters( backend: Backend, command_names: List[str], params: List[str] ) -> None: @@ -487,9 +274,8 @@ def validate_parameters( if not acceptable: backend_type = "System" if isinstance(backend, System) else "Application" raise ValueError( - "{} parameter '{}' not valid for command '{}'".format( - backend_type, param, " or ".join(command_names) - ) + f"{backend_type} parameter '{param}' not valid for " + f"command '{' or '.join(command_names)}'." ) @@ -500,16 +286,14 @@ def get_application_by_name_and_system( applications = get_application(application_name, system_name) if not applications: raise ValueError( - "Application '{}' doesn't support the system '{}'".format( - application_name, system_name - ) + f"Application '{application_name}' doesn't support the " + f"system '{system_name}'." ) if len(applications) != 1: raise ValueError( - "Error during getting application {} for the system {}".format( - application_name, system_name - ) + f"Error during getting application {application_name} for the " + f"system {system_name}." ) return applications[0] @@ -521,259 +305,41 @@ def get_application_and_system( """Return application and system by provided names.""" system = get_system(system_name) if not system: - raise ValueError("System {} is not found".format(system_name)) + raise ValueError(f"System {system_name} is not found.") application = get_application_by_name_and_system(application_name, system_name) return application, system -# pylint: disable=too-many-arguments def run_application( application_name: str, application_params: List[str], system_name: str, system_params: List[str], - custom_deploy_data: List[DataPaths], - report_file: Optional[Path] = None, ) -> ExecutionContext: """Run application on the provided system.""" application, system = get_application_and_system(application_name, system_name) - validate_parameters(application, ["build", "run"], application_params) - validate_parameters(system, ["build", "run"], system_params) - - execution_params = ExecutionParams() - if isinstance(system, StandaloneSystem): - execution_params["disable_locking"] = True - execution_params["unique_build_dir"] = True + validate_parameters(application, ["run"], application_params) + validate_parameters(system, ["run"], system_params) ctx = ExecutionContext( app=application, app_params=application_params, system=system, system_params=system_params, - custom_deploy_data=custom_deploy_data, - execution_params=execution_params, - report_file=report_file, ) - with build_dir_manager(ctx): - if ctx.is_build_required: - execute_application_command_build(ctx) - - execute_application_command_run(ctx) - - return ctx - - -def execute_application_command_build(ctx: ExecutionContext) -> None: - """Execute application command 'build'.""" - with ExitStack() as context_stack: - for manager in get_context_managers("build", ctx): - context_stack.enter_context(manager(ctx)) - - build_dir = ctx.build_dir() - recreate_directory(build_dir) - - build_commands = ctx.app.build_command( - "build", ctx.app_params, ctx.param_resolver - ) - execute_commands_locally(build_commands, build_dir) - - -def execute_commands_locally(commands: List[str], cwd: Path) -> None: - """Execute list of commands locally.""" - for command in commands: - print("Running: {}".format(command)) - run_and_wait( - command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr - ) - - -def execute_application_command_run(ctx: ExecutionContext) -> None: - """Execute application command.""" - assert ctx.system is not None, "System must be provided." - if ctx.is_deploy_needed and not ctx.system.supports_deploy: - raise ConfigurationException( - "System {} does not support data deploy".format(ctx.system.name) - ) - - with ExitStack() as context_stack: - for manager in get_context_managers("run", ctx): - context_stack.enter_context(manager(ctx)) - - print("Generating commands to execute") - commands_to_run = build_run_commands(ctx) - - if ctx.system.connectable: - establish_connection(ctx) - - if ctx.system.supports_deploy: - deploy_data(ctx) - - for command in commands_to_run: - print("Running: {}".format(command)) - exit_code, ctx.stdout, ctx.stderr = ctx.system.run(command) - - if exit_code != 0: - print("Application exited with exit code {}".format(exit_code)) - - if ctx.reporter: - ctx.reporter.parse(ctx.stdout) - ctx.stdout = ctx.reporter.get_filtered_output(ctx.stdout) - - if ctx.reporter: - report = ctx.reporter.report(ctx) - ctx.reporter.save(report, cast(Path, ctx.report_file)) - - -def establish_connection( - ctx: ExecutionContext, retries: int = 90, interval: float = 15.0 -) -> None: - """Establish connection with the system.""" - assert ctx.system is not None, "System is required." - host, port = ctx.system.connection_details() - print( - "Trying to establish connection with '{}:{}' - " - "{} retries every {} seconds ".format(host, port, retries, interval), - end="", + logger.debug("Generating commands to execute") + commands_to_run = ctx.system.build_command( + "run", ctx.system_params, ctx.param_resolver ) - try: - for _ in range(retries): - print(".", end="", flush=True) - - if ctx.system.establish_connection(): - break - - if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running(): - print( - "\n\n---------- {} execution failed ----------".format( - ctx.system.name - ) - ) - stdout, stderr = ctx.system.get_output() - print(stdout) - print(stderr) - - raise Exception("System is not running") + for command in commands_to_run: + logger.debug("Running: %s", command) + exit_code, ctx.stdout, ctx.stderr = ctx.system.run(command) - wait(interval) - else: - raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port)) - finally: - print() - - -def wait(interval: float) -> None: - """Wait for a period of time.""" - time.sleep(interval) - - -def deploy_data(ctx: ExecutionContext) -> None: - """Deploy data to the system.""" - assert ctx.system is not None, "System is required." - for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data): - print("Deploying {} onto {}".format(item.src, item.dst)) - ctx.system.deploy(item.src, item.dst) - - -def build_run_commands(ctx: ExecutionContext) -> List[str]: - """Build commands to run application.""" - if isinstance(ctx.system, StandaloneSystem): - return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver) - - return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver) - - -@contextmanager -def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Context manager used for system initialisation before run.""" - system = cast(ControlledSystem, ctx.system) - commands = system.build_command("run", ctx.system_params, ctx.param_resolver) - pid_file_path: Optional[Path] = None - if ctx.is_locking_required: - file_lock_path = get_file_lock_path(ctx) - pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem) - - system.start(commands, ctx.is_locking_required, pid_file_path) - try: - yield - finally: - print("Shutting down sequence...") - print("Stopping {}... (It could take few seconds)".format(system.name)) - system.stop(wait=True) - print("{} stopped successfully.".format(system.name)) - - -@contextmanager -def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Lock execution manager.""" - file_lock_path = get_file_lock_path(ctx) - file_lock = FileLock(str(file_lock_path)) - - try: - file_lock.acquire(timeout=1) - except Timeout as error: - raise AnotherInstanceIsRunningException() from error - - try: - yield - finally: - file_lock.release() - - -def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path: - """Get file lock path.""" - lock_modules = [] - if ctx.app.lock: - lock_modules.append(ctx.app.name) - if ctx.system is not None and ctx.system.lock: - lock_modules.append(ctx.system.name) - lock_filename = "" - if lock_modules: - lock_filename = "_".join(["middleware"] + lock_modules) + ".lock" - - if lock_filename: - lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver) - lock_filename = valid_for_filename(lock_filename) - - if not lock_filename: - raise ConfigurationException("No filename for lock provided") - - if not isinstance(lock_dir, Path) or not lock_dir.is_dir(): - raise ConfigurationException( - "Invalid directory {} for lock files provided".format(lock_dir) - ) - - return lock_dir / lock_filename - - -@contextmanager -def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Build directory manager.""" - try: - yield - finally: - if ( - ctx.is_build_required - and ctx.is_unique_build_dir_required - and ctx.build_dir().is_dir() - ): - remove_directory(ctx.build_dir()) + if exit_code != 0: + logger.warning("Application exited with exit code %i", exit_code) - -def get_context_managers( - command_name: str, ctx: ExecutionContext -) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]: - """Get context manager for the system.""" - managers = [] - - if ctx.is_locking_required: - managers.append(lock_execution_manager) - - if command_name == "run": - if isinstance(ctx.system, ControlledSystem): - managers.append(controlled_system_manager) - - return managers + return ctx diff --git a/src/mlia/backend/fs.py b/src/mlia/backend/fs.py index 9979fcb..9fb53b1 100644 --- a/src/mlia/backend/fs.py +++ b/src/mlia/backend/fs.py @@ -4,7 +4,6 @@ import re import shutil from pathlib import Path -from typing import Any from typing import Literal from typing import Optional @@ -30,7 +29,7 @@ def get_backends_path(name: ResourceType) -> Path: if resource_path.is_dir(): return resource_path - raise ResourceWarning("Resource '{}' not found.".format(name)) + raise ResourceWarning(f"Resource '{name}' not found.") def copy_directory_content(source: Path, destination: Path) -> None: @@ -51,10 +50,10 @@ def remove_resource(resource_directory: str, resource_type: ResourceType) -> Non resource_location = resources / resource_directory if not resource_location.exists(): - raise Exception("Resource {} does not exist".format(resource_directory)) + raise Exception(f"Resource {resource_directory} does not exist") if not resource_location.is_dir(): - raise Exception("Wrong resource {}".format(resource_directory)) + raise Exception(f"Wrong resource {resource_directory}") shutil.rmtree(resource_location) @@ -74,7 +73,7 @@ def recreate_directory(directory_path: Optional[Path]) -> None: if directory_path.exists() and not directory_path.is_dir(): raise Exception( - "Path {} does exist and it is not a directory".format(str(directory_path)) + f"Path {str(directory_path)} does exist and it is not a directory." ) if directory_path.is_dir(): @@ -83,33 +82,6 @@ def recreate_directory(directory_path: Optional[Path]) -> None: directory_path.mkdir() -def read_file(file_path: Path, mode: Optional[str] = None) -> Any: - """Read file as string or bytearray.""" - if file_path.is_file(): - if mode is not None: - # Ignore pylint warning because mode can be 'binary' as well which - # is not compatible with specifying encodings. - with open(file_path, mode) as file: # pylint: disable=unspecified-encoding - return file.read() - else: - with open(file_path, encoding="utf-8") as file: - return file.read() - - if mode == "rb": - return b"" - return "" - - -def read_file_as_string(file_path: Path) -> str: - """Read file as string.""" - return str(read_file(file_path)) - - -def read_file_as_bytearray(file_path: Path) -> bytearray: - """Read a file as bytearray.""" - return bytearray(read_file(file_path, mode="rb")) - - def valid_for_filename(value: str, replacement: str = "") -> str: """Replace non alpha numeric characters.""" return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII) diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py index 3a1016c..8d8246d 100644 --- a/src/mlia/backend/manager.py +++ b/src/mlia/backend/manager.py @@ -2,27 +2,25 @@ # SPDX-License-Identifier: Apache-2.0 """Module for backend integration.""" import logging -import re from abc import ABC from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Any from typing import Dict from typing import List from typing import Literal from typing import Optional +from typing import Set from typing import Tuple from mlia.backend.application import get_available_applications from mlia.backend.application import install_application -from mlia.backend.common import DataPaths from mlia.backend.execution import ExecutionContext from mlia.backend.execution import run_application +from mlia.backend.output_consumer import Base64OutputConsumer +from mlia.backend.output_consumer import OutputConsumer from mlia.backend.system import get_available_systems from mlia.backend.system import install_system -from mlia.utils.proc import OutputConsumer -from mlia.utils.proc import RunningCommand logger = logging.getLogger(__name__) @@ -128,89 +126,55 @@ class ExecutionParams: system: str application_params: List[str] system_params: List[str] - deploy_params: List[str] class LogWriter(OutputConsumer): """Redirect output to the logger.""" - def feed(self, line: str) -> None: + def feed(self, line: str) -> bool: """Process line from the output.""" logger.debug(line.strip()) + return False -class GenericInferenceOutputParser(OutputConsumer): +class GenericInferenceOutputParser(Base64OutputConsumer): """Generic inference app output parser.""" - PATTERNS = { - name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns) - for name, patterns in ( - ( - "npu_active_cycles", - ( - r"NPU ACTIVE cycles: (?P<value>\d+)", - r"NPU ACTIVE: (?P<value>\d+) cycles", - ), - ), - ( - "npu_idle_cycles", - ( - r"NPU IDLE cycles: (?P<value>\d+)", - r"NPU IDLE: (?P<value>\d+) cycles", - ), - ), - ( - "npu_total_cycles", - ( - r"NPU TOTAL cycles: (?P<value>\d+)", - r"NPU TOTAL: (?P<value>\d+) cycles", - ), - ), - ( - "npu_axi0_rd_data_beat_received", - ( - r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)", - r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats", - ), - ), - ( - "npu_axi0_wr_data_beat_written", - ( - r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P<value>\d+)", - r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P<value>\d+) beats", - ), - ), - ( - "npu_axi1_rd_data_beat_received", - ( - r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)", - r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats", - ), - ), - ) - } - def __init__(self) -> None: """Init generic inference output parser instance.""" - self.result: Dict = {} - - def feed(self, line: str) -> None: - """Feed new line to the parser.""" - for name, patterns in self.PATTERNS.items(): - for pattern in patterns: - match = pattern.search(line) - - if match: - self.result[name] = int(match["value"]) - return + super().__init__() + self._map = { + "NPU ACTIVE": "npu_active_cycles", + "NPU IDLE": "npu_idle_cycles", + "NPU TOTAL": "npu_total_cycles", + "NPU AXI0_RD_DATA_BEAT_RECEIVED": "npu_axi0_rd_data_beat_received", + "NPU AXI0_WR_DATA_BEAT_WRITTEN": "npu_axi0_wr_data_beat_written", + "NPU AXI1_RD_DATA_BEAT_RECEIVED": "npu_axi1_rd_data_beat_received", + } + + @property + def result(self) -> Dict: + """Merge the raw results and map the names to the right output names.""" + merged_result = {} + for raw_result in self.parsed_output: + for profiling_result in raw_result: + for sample in profiling_result["samples"]: + name, values = (sample["name"], sample["value"]) + if name in merged_result: + raise KeyError( + f"Duplicate key '{name}' in base64 output.", + ) + new_name = self._map[name] + merged_result[new_name] = values[0] + return merged_result def is_ready(self) -> bool: """Return true if all expected data has been parsed.""" - return self.result.keys() == self.PATTERNS.keys() + return set(self.result.keys()) == set(self._map.values()) - def missed_keys(self) -> List[str]: - """Return list of the keys that have not been found in the output.""" - return sorted(self.PATTERNS.keys() - self.result.keys()) + def missed_keys(self) -> Set[str]: + """Return a set of the keys that have not been found in the output.""" + return set(self._map.values()) - set(self.result.keys()) class BackendRunner: @@ -274,24 +238,12 @@ class BackendRunner: @staticmethod def run_application(execution_params: ExecutionParams) -> ExecutionContext: """Run requested application.""" - - def to_data_paths(paths: str) -> DataPaths: - """Split input into two and create new DataPaths object.""" - src, dst = paths.split(sep=":", maxsplit=1) - return DataPaths(Path(src), dst) - - deploy_data_paths = [ - to_data_paths(paths) for paths in execution_params.deploy_params - ] - ctx = run_application( execution_params.application, execution_params.application_params, execution_params.system, execution_params.system_params, - deploy_data_paths, ) - return ctx @staticmethod @@ -305,7 +257,6 @@ class GenericInferenceRunner(ABC): def __init__(self, backend_runner: BackendRunner): """Init generic inference runner instance.""" self.backend_runner = backend_runner - self.running_inference: Optional[RunningCommand] = None def run( self, model_info: ModelInfo, output_consumers: List[OutputConsumer] @@ -315,27 +266,12 @@ class GenericInferenceRunner(ABC): ctx = self.backend_runner.run_application(execution_params) if ctx.stdout is not None: - self.consume_output(ctx.stdout, output_consumers) - - def stop(self) -> None: - """Stop running inference.""" - if self.running_inference is None: - return - - self.running_inference.stop() + ctx.stdout = self.consume_output(ctx.stdout, output_consumers) @abstractmethod def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: """Get execution params for the provided model.""" - def __enter__(self) -> "GenericInferenceRunner": - """Enter context.""" - return self - - def __exit__(self, *_args: Any) -> None: - """Exit context.""" - self.stop() - def check_system_and_application(self, system_name: str, app_name: str) -> None: """Check if requested system and application installed.""" if not self.backend_runner.is_system_installed(system_name): @@ -348,11 +284,23 @@ class GenericInferenceRunner(ABC): ) @staticmethod - def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> None: - """Pass program's output to the consumers.""" - for line in output.decode("utf8").splitlines(): + def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> bytearray: + """ + Pass program's output to the consumers and filter it. + + Returns the filtered output. + """ + filtered_output = bytearray() + for line_bytes in output.splitlines(): + line = line_bytes.decode("utf-8") + remove_line = False for consumer in consumers: - consumer.feed(line) + if consumer.feed(line): + remove_line = True + if not remove_line: + filtered_output.extend(line_bytes) + + return filtered_output class GenericInferenceRunnerEthosU(GenericInferenceRunner): @@ -408,7 +356,6 @@ class GenericInferenceRunnerEthosU(GenericInferenceRunner): self.system_name, [], system_params, - [], ) @@ -422,20 +369,18 @@ def estimate_performance( model_info: ModelInfo, device_info: DeviceInfo, backend: str ) -> PerformanceMetrics: """Get performance estimations.""" - with get_generic_runner(device_info, backend) as generic_runner: - output_parser = GenericInferenceOutputParser() - output_consumers = [output_parser, LogWriter()] + output_parser = GenericInferenceOutputParser() + output_consumers = [output_parser, LogWriter()] - generic_runner.run(model_info, output_consumers) + generic_runner = get_generic_runner(device_info, backend) + generic_runner.run(model_info, output_consumers) - if not output_parser.is_ready(): - missed_data = ",".join(output_parser.missed_keys()) - logger.debug( - "Unable to get performance metrics, missed data %s", missed_data - ) - raise Exception("Unable to get performance metrics, insufficient data") + if not output_parser.is_ready(): + missed_data = ",".join(output_parser.missed_keys()) + logger.debug("Unable to get performance metrics, missed data %s", missed_data) + raise Exception("Unable to get performance metrics, insufficient data") - return PerformanceMetrics(**output_parser.result) + return PerformanceMetrics(**output_parser.result) def get_backend_runner() -> BackendRunner: diff --git a/src/mlia/backend/output_consumer.py b/src/mlia/backend/output_consumer.py new file mode 100644 index 0000000..bac4186 --- /dev/null +++ b/src/mlia/backend/output_consumer.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Output consumers module.""" +import base64 +import json +import re +from typing import List +from typing import Protocol +from typing import runtime_checkable + + +@runtime_checkable +class OutputConsumer(Protocol): + """Protocol to consume output.""" + + def feed(self, line: str) -> bool: + """ + Feed a new line to be parsed. + + Return True if the line should be removed from the output. + """ + + +class Base64OutputConsumer(OutputConsumer): + """ + Parser to extract base64-encoded JSON from tagged standard output. + + Example of the tagged output: + ``` + # Encoded JSON: {"test": 1234} + <metrics>eyJ0ZXN0IjogMTIzNH0</metrics> + ``` + """ + + TAG_NAME = "metrics" + + def __init__(self) -> None: + """Set up the regular expression to extract tagged strings.""" + self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>") + self.parsed_output: List = [] + + def feed(self, line: str) -> bool: + """ + Parse the output line and save the decoded output. + + Returns True if the line contains tagged output. + + Example: + Using the tagged output from the class docs the parser should collect + the following: + ``` + [ + {"test": 1234} + ] + ``` + """ + res_b64 = self._regex.search(line) + if res_b64: + res_json = base64.b64decode(res_b64.group(1), validate=True) + res = json.loads(res_json) + self.parsed_output.append(res) + # Remove this line from the output, i.e. consume it, as it + # does not contain any human readable content. + return True + + return False diff --git a/src/mlia/backend/output_parser.py b/src/mlia/backend/output_parser.py deleted file mode 100644 index 111772a..0000000 --- a/src/mlia/backend/output_parser.py +++ /dev/null @@ -1,176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Definition of output parsers (including base class OutputParser).""" -import base64 -import json -import re -from abc import ABC -from abc import abstractmethod -from typing import Any -from typing import Dict -from typing import Union - - -class OutputParser(ABC): - """Abstract base class for output parsers.""" - - def __init__(self, name: str) -> None: - """Set up the name of the parser.""" - super().__init__() - self.name = name - - @abstractmethod - def __call__(self, output: bytearray) -> Dict[str, Any]: - """Parse the output and return a map of names to metrics.""" - return {} - - # pylint: disable=no-self-use - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """ - Filter out the parsed content from the output. - - Does nothing by default. Can be overridden in subclasses. - """ - return output - - -class RegexOutputParser(OutputParser): - """Parser of standard output data using regular expressions.""" - - _TYPE_MAP = {"str": str, "float": float, "int": int} - - def __init__( - self, - name: str, - regex_config: Dict[str, Dict[str, str]], - ) -> None: - """ - Set up the parser with the regular expressions. - - The regex_config is mapping from a name to a dict with keys 'pattern' - and 'type': - - The 'pattern' holds the regular expression that must contain exactly - one capturing parenthesis - - The 'type' can be one of ['str', 'float', 'int']. - - Example: - ``` - {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}} - ``` - - The different regular expressions from the config are combined using - non-capturing parenthesis, i.e. regular expressions must not overlap - if more than one match per line is expected. - """ - super().__init__(name) - - self._verify_config(regex_config) - self._regex_cfg = regex_config - - # Compile regular expression to match in the output - self._regex = re.compile( - "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values()) - ) - - def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]: - """ - Parse the output and return a map of names to metrics. - - Example: - Assuming a regex_config as used as example in `__init__()` and the - following output: - ``` - Simulation finished: - SIMULATION_STATUS = SUCCESS - Simulation DONE - ``` - Then calling the parser should return the following dict: - ``` - { - "Metric1": "SUCCESS" - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for line_result in results: - for idx, (name, cfg) in enumerate(self._regex_cfg.items()): - # The result(s) returned by findall() are either a single string - # or a tuple (depending on the number of groups etc.) - result = ( - line_result if isinstance(line_result, str) else line_result[idx] - ) - if result: - mapped_result = self._TYPE_MAP[cfg["type"]](result) - metrics[name] = mapped_result - return metrics - - def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None: - """Make sure we have a valid regex_config. - - I.e. - - Exactly one capturing parenthesis per pattern - - Correct types - """ - for name, cfg in regex_config.items(): - # Check that there is one capturing group defined in the pattern. - regex = re.compile(cfg["pattern"]) - if regex.groups != 1: - raise ValueError( - f"Pattern for metric '{name}' must have exactly one " - f"capturing parenthesis, but it has {regex.groups}." - ) - # Check if type is supported - if not cfg["type"] in self._TYPE_MAP: - raise TypeError( - f"Type '{cfg['type']}' for metric '{name}' is not " - f"supported. Choose from: {list(self._TYPE_MAP.keys())}." - ) - - -class Base64OutputParser(OutputParser): - """ - Parser to extract base64-encoded JSON from tagged standard output. - - Example of the tagged output: - ``` - # Encoded JSON: {"test": 1234} - <metrics>eyJ0ZXN0IjogMTIzNH0</metrics> - ``` - """ - - TAG_NAME = "metrics" - - def __init__(self, name: str) -> None: - """Set up the regular expression to extract tagged strings.""" - super().__init__(name) - self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>") - - def __call__(self, output: bytearray) -> Dict[str, Any]: - """ - Parse the output and return a map of index (as string) to decoded JSON. - - Example: - Using the tagged output from the class docs the parser should return - the following dict: - ``` - { - "0": {"test": 1234} - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for idx, result_base64 in enumerate(results): - result_json = base64.b64decode(result_base64, validate=True) - result = json.loads(result_json) - metrics[str(idx)] = result - - return metrics - - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """Filter out base64-encoded content from the output.""" - output_str = self._regex.sub("", output.decode("utf-8")) - return bytearray(output_str.encode("utf-8")) diff --git a/src/mlia/backend/proc.py b/src/mlia/backend/proc.py index 90ff414..a4c0be3 100644 --- a/src/mlia/backend/proc.py +++ b/src/mlia/backend/proc.py @@ -5,7 +5,6 @@ This module contains all classes and functions for dealing with Linux processes. """ -import csv import datetime import logging import shlex @@ -14,11 +13,9 @@ import time from pathlib import Path from typing import Any from typing import List -from typing import NamedTuple from typing import Optional from typing import Tuple -import psutil from sh import Command from sh import CommandNotFound from sh import ErrorReturnCode @@ -26,6 +23,8 @@ from sh import RunningCommand from mlia.backend.fs import valid_for_filename +logger = logging.getLogger(__name__) + class CommandFailedException(Exception): """Exception for failed command execution.""" @@ -50,7 +49,7 @@ class ShellCommand: _bg: bool = True, _out: Any = None, _err: Any = None, - _search_paths: Optional[List[Path]] = None + _search_paths: Optional[List[Path]] = None, ) -> RunningCommand: """Run the shell command with the given arguments. @@ -86,7 +85,7 @@ class ShellCommand: """Construct and returns the paths of stdout/stderr files.""" timestamp = datetime.datetime.now().timestamp() base_path = Path(base_log_path) - base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) + base = base_path / f"{valid_for_filename(cmd, '_')}_{timestamp}" stdout = base.with_suffix(".out") stderr = base.with_suffix(".err") try: @@ -164,7 +163,7 @@ def run_and_wait( except Exception as error: is_running = running_cmd is not None and running_cmd.is_alive() if terminate_on_error and is_running: - print("Terminating ...") + logger.debug("Terminating ...") terminate_command(running_cmd) raise error @@ -184,87 +183,15 @@ def terminate_command( time.sleep(wait_period) if not running_cmd.is_alive(): return - print( - "Unable to terminate process {}. Sending SIGTERM...".format( - running_cmd.process.pid - ) + logger.error( + "Unable to terminate process %i. Sending SIGTERM...", + running_cmd.process.pid, ) running_cmd.process.signal_group(signal.SIGTERM) except ProcessLookupError: pass -def terminate_external_process( - process: psutil.Process, - wait_period: float = 0.5, - number_of_attempts: int = 20, - wait_for_termination: float = 5.0, -) -> None: - """Terminate external process.""" - try: - process.terminate() - for _ in range(number_of_attempts): - if not process.is_running(): - return - time.sleep(wait_period) - - if process.is_running(): - process.terminate() - time.sleep(wait_for_termination) - except psutil.Error: - print("Unable to terminate process") - - -class ProcessInfo(NamedTuple): - """Process information.""" - - name: str - executable: str - cwd: str - pid: int - - -def save_process_info(pid: int, pid_file: Path) -> None: - """Save process information to file.""" - try: - parent = psutil.Process(pid) - children = parent.children(recursive=True) - family = [parent] + children - - with open(pid_file, "w", encoding="utf-8") as file: - csv_writer = csv.writer(file) - for member in family: - process_info = ProcessInfo( - name=member.name(), - executable=member.exe(), - cwd=member.cwd(), - pid=member.pid, - ) - csv_writer.writerow(process_info) - except psutil.NoSuchProcess: - # if process does not exist or finishes before - # function call then nothing could be saved - # just ignore this exception and exit - pass - - -def read_process_info(pid_file: Path) -> List[ProcessInfo]: - """Read information about previous system processes.""" - if not pid_file.is_file(): - return [] - - result = [] - with open(pid_file, encoding="utf-8") as file: - csv_reader = csv.reader(file) - for row in csv_reader: - name, executable, cwd, pid = row - result.append( - ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) - ) - - return result - - def print_command_stdout(command: RunningCommand) -> None: """Print the stdout of a command. diff --git a/src/mlia/backend/protocol.py b/src/mlia/backend/protocol.py deleted file mode 100644 index ebfe69a..0000000 --- a/src/mlia/backend/protocol.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain protocol related classes and functions.""" -from abc import ABC -from abc import abstractmethod -from contextlib import closing -from pathlib import Path -from typing import Any -from typing import cast -from typing import Iterable -from typing import Optional -from typing import Tuple -from typing import Union - -import paramiko - -from mlia.backend.common import ConfigurationException -from mlia.backend.config import LocalProtocolConfig -from mlia.backend.config import SSHConfig -from mlia.backend.proc import run_and_wait - - -# Redirect all paramiko thread exceptions to a file otherwise these will be -# printed to stderr. -paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) - - -class SSHConnectionException(Exception): - """SSH connection exception.""" - - -class SupportsClose(ABC): - """Class indicates support of close operation.""" - - @abstractmethod - def close(self) -> None: - """Close protocol session.""" - - -class SupportsDeploy(ABC): - """Class indicates support of deploy operation.""" - - @abstractmethod - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Abstract method for deploy data.""" - - -class SupportsConnection(ABC): - """Class indicates that protocol uses network connections.""" - - @abstractmethod - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - - @abstractmethod - def connection_details(self) -> Tuple[str, int]: - """Return connection details (host, port).""" - - -class Protocol(ABC): - """Abstract class for representing the protocol.""" - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - self.__dict__.update(iterable, **kwargs) - self._validate() - - @abstractmethod - def _validate(self) -> None: - """Abstract method for config data validation.""" - - @abstractmethod - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Abstract method for running commands. - - Returns a tuple: (exit_code, stdout, stderr) - """ - - -class CustomSFTPClient(paramiko.SFTPClient): - """Class for creating a custom sftp client.""" - - def put_dir(self, source: Path, target: str) -> None: - """Upload the source directory to the target path. - - The target directory needs to exists and the last directory of the - source will be created under the target with all its content. - """ - # Create the target directory - self._mkdir(target, ignore_existing=True) - # Create the last directory in the source on the target - self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) - # Go through the whole content of source - for item in sorted(source.glob("**/*")): - relative_path = item.relative_to(source.parent) - remote_target = target / relative_path - if item.is_file(): - self.put(str(item), str(remote_target)) - else: - self._mkdir(str(remote_target), ignore_existing=True) - - def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: - """Extend mkdir functionality. - - This version adds an option to not fail if the folder exists. - """ - try: - super().mkdir(path, mode) - except IOError as error: - if ignore_existing: - pass - else: - raise error - - -class LocalProtocol(Protocol): - """Class for local protocol.""" - - protocol: str - cwd: Path - - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Run command locally. - - Returns a tuple: (exit_code, stdout, stderr) - """ - if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(self.cwd)) - - stdout = bytearray() - stderr = bytearray() - - return run_and_wait( - command, self.cwd, terminate_on_error=True, out=stdout, err=stderr - ) - - def _validate(self) -> None: - """Validate protocol configuration.""" - assert hasattr(self, "protocol") and self.protocol == "local" - assert hasattr(self, "cwd") - - -class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): - """Class for SSH protocol.""" - - protocol: str - username: str - password: str - hostname: str - port: int - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - super().__init__(iterable, **kwargs) - # Internal state to store if the system is connectable. It will be set - # to true at the first connection instance - self.client: Optional[paramiko.client.SSHClient] = None - self.port = int(self.port) - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command over SSH. - - Returns a tuple: (exit_code, stdout, stderr) - """ - transport = self._get_transport() - with closing(transport.open_session()) as channel: - # Enable shell's .profile settings and execute command - channel.exec_command("bash -l -c '{}'".format(command)) - exit_status = -1 - stdout = bytearray() - stderr = bytearray() - while True: - if channel.exit_status_ready(): - exit_status = channel.recv_exit_status() - # Call it one last time to read any leftover in the channel - self._recv_stdout_err(channel, stdout, stderr) - break - self._recv_stdout_err(channel, stdout, stderr) - - return exit_status, stdout, stderr - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy src to remote dst over SSH. - - src and dst should be path to a file or directory. - """ - transport = self._get_transport() - sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) - - with closing(sftp): - if src.is_dir(): - sftp.put_dir(src, dst) - elif src.is_file(): - sftp.put(str(src), dst) - else: - raise Exception("Deploy error: file type not supported") - - # After the deployment of files, sync the remote filesystem to flush - # buffers to hard disk - self.run("sync") - - def close(self) -> None: - """Close protocol session.""" - if self.client is not None: - print("Try syncing remote file system...") - # Before stopping the system, we try to run sync to make sure all - # data are flushed on disk. - self.run("sync", retry=False) - self._close_client(self.client) - - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - if self.client is not None: - return True - - self.client = self._connect() - return self.client is not None - - def _get_transport(self) -> paramiko.transport.Transport: - """Get transport.""" - self.establish_connection() - - if self.client is None: - raise SSHConnectionException( - "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) - ) - - transport = self.client.get_transport() - if not transport: - raise Exception("Unable to get transport") - - return transport - - def connection_details(self) -> Tuple[str, int]: - """Return connection details of underlying system.""" - return (self.hostname, self.port) - - def _connect(self) -> Optional[paramiko.client.SSHClient]: - """Try to establish connection.""" - client: Optional[paramiko.client.SSHClient] = None - try: - client = paramiko.client.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - self.hostname, - self.port, - self.username, - self.password, - # next parameters should be set to False to disable authentication - # using ssh keys - allow_agent=False, - look_for_keys=False, - ) - return client - except ( - # OSError raised on first attempt to connect when running inside Docker - OSError, - paramiko.ssh_exception.NoValidConnectionsError, - paramiko.ssh_exception.SSHException, - ): - # even if connection is not established socket could be still - # open, it should be closed - self._close_client(client) - - return None - - @staticmethod - def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: - """Close ssh client.""" - try: - if client is not None: - client.close() - except Exception: # pylint: disable=broad-except - pass - - @classmethod - def _recv_stdout_err( - cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray - ) -> None: - """Read from channel to stdout/stder.""" - chunk_size = 512 - if channel.recv_ready(): - stdout_chunk = channel.recv(chunk_size) - stdout.extend(stdout_chunk) - if channel.recv_stderr_ready(): - stderr_chunk = channel.recv_stderr(chunk_size) - stderr.extend(stderr_chunk) - - def _validate(self) -> None: - """Check if there are all the info for establishing the connection.""" - assert hasattr(self, "protocol") and self.protocol == "ssh" - assert hasattr(self, "username") - assert hasattr(self, "password") - assert hasattr(self, "hostname") - assert hasattr(self, "port") - - -class ProtocolFactory: - """Factory class to return the appropriate Protocol class.""" - - @staticmethod - def get_protocol( - config: Optional[Union[SSHConfig, LocalProtocolConfig]], - **kwargs: Union[str, Path, None] - ) -> Union[SSHProtocol, LocalProtocol]: - """Return the right protocol instance based on the config.""" - if not config: - raise ValueError("No protocol config provided") - - protocol = config["protocol"] - if protocol == "ssh": - return SSHProtocol(config) - - if protocol == "local": - cwd = kwargs.get("cwd") - return LocalProtocol(config, cwd=cwd) - - raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/mlia/backend/source.py b/src/mlia/backend/source.py index dcf6835..f80a774 100644 --- a/src/mlia/backend/source.py +++ b/src/mlia/backend/source.py @@ -63,11 +63,11 @@ class DirectorySource(Source): def install_into(self, destination: Path) -> None: """Install source into destination directory.""" if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) + raise ConfigurationException(f"Wrong destination {destination}.") if not self.directory_path.is_dir(): raise ConfigurationException( - "Directory {} does not exist".format(self.directory_path) + f"Directory {self.directory_path} does not exist." ) copy_directory_content(self.directory_path, destination) @@ -112,7 +112,7 @@ class TarArchiveSource(Source): "Archive has no top level directory" ) from error_no_config - config_path = "{}/{}".format(top_level_dir, BACKEND_CONFIG_FILE) + config_path = f"{top_level_dir}/{BACKEND_CONFIG_FILE}" config_entry = archive.getmember(config_path) self._has_top_level_folder = True @@ -149,7 +149,7 @@ class TarArchiveSource(Source): def install_into(self, destination: Path) -> None: """Install source into destination directory.""" if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) + raise ConfigurationException(f"Wrong destination {destination}.") with self._open(self.archive_path) as archive: archive.extractall(destination) @@ -157,14 +157,12 @@ class TarArchiveSource(Source): def _open(self, archive_path: Path) -> TarFile: """Open archive file.""" if not archive_path.is_file(): - raise ConfigurationException("File {} does not exist".format(archive_path)) + raise ConfigurationException(f"File {archive_path} does not exist.") if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): mode = "r:gz" else: - raise ConfigurationException( - "Unsupported archive type {}".format(archive_path) - ) + raise ConfigurationException(f"Unsupported archive type {archive_path}.") # The returned TarFile object can be used as a context manager (using # 'with') by the calling instance. @@ -181,7 +179,7 @@ def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: if source_path.is_dir(): return DirectorySource(source_path) - raise ConfigurationException("Unable to read {}".format(source_path)) + raise ConfigurationException(f"Unable to read {source_path}.") def create_destination_and_install(source: Source, resource_path: Path) -> None: @@ -197,7 +195,7 @@ def create_destination_and_install(source: Source, resource_path: Path) -> None: if create_destination: name = source.name() if not name: - raise ConfigurationException("Unable to get source name") + raise ConfigurationException("Unable to get source name.") destination = resource_path / name destination.mkdir() diff --git a/src/mlia/backend/system.py b/src/mlia/backend/system.py index 469083e..ff85bf3 100644 --- a/src/mlia/backend/system.py +++ b/src/mlia/backend/system.py @@ -6,9 +6,7 @@ from typing import Any from typing import cast from typing import Dict from typing import List -from typing import Optional from typing import Tuple -from typing import Union from mlia.backend.common import Backend from mlia.backend.common import ConfigurationException @@ -17,72 +15,12 @@ from mlia.backend.common import get_backend_directories from mlia.backend.common import load_config from mlia.backend.common import remove_backend from mlia.backend.config import SystemConfig -from mlia.backend.controller import SystemController -from mlia.backend.controller import SystemControllerSingleInstance from mlia.backend.fs import get_backends_path -from mlia.backend.protocol import ProtocolFactory -from mlia.backend.protocol import SupportsClose -from mlia.backend.protocol import SupportsConnection -from mlia.backend.protocol import SupportsDeploy +from mlia.backend.proc import run_and_wait from mlia.backend.source import create_destination_and_install from mlia.backend.source import get_source -def get_available_systems_directory_names() -> List[str]: - """Return a list of directory names for all avialable systems.""" - return [entry.name for entry in get_backend_directories("systems")] - - -def get_available_systems() -> List["System"]: - """Return a list with all available systems.""" - available_systems = [] - for config_json in get_backend_configs("systems"): - config_entries = cast(List[SystemConfig], (load_config(config_json))) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - system = load_system(config_entry) - available_systems.append(system) - - return sorted(available_systems, key=lambda system: system.name) - - -def get_system(system_name: str) -> Optional["System"]: - """Return a system instance with the same name passed as argument.""" - available_systems = get_available_systems() - for system in available_systems: - if system_name == system.name: - return system - return None - - -def install_system(source_path: Path) -> None: - """Install new system.""" - try: - source = get_source(source_path) - config = cast(List[SystemConfig], source.config()) - systems_to_install = [load_system(entry) for entry in config] - except Exception as error: - raise ConfigurationException("Unable to read system definition") from error - - if not systems_to_install: - raise ConfigurationException("No system definition found") - - available_systems = get_available_systems() - already_installed = [s for s in systems_to_install if s in available_systems] - if already_installed: - names = [system.name for system in already_installed] - raise ConfigurationException( - "Systems [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_backends_path("systems")) - - -def remove_system(directory_name: str) -> None: - """Remove system.""" - remove_backend(directory_name, "systems") - - class System(Backend): """System class.""" @@ -90,59 +28,33 @@ class System(Backend): """Construct the System class using the dictionary passed.""" super().__init__(config) - self._setup_data_transfer(config) self._setup_reporting(config) - def _setup_data_transfer(self, config: SystemConfig) -> None: - data_transfer_config = config.get("data_transfer") - protocol = ProtocolFactory().get_protocol( - data_transfer_config, cwd=self.config_location - ) - self.protocol = protocol - def _setup_reporting(self, config: SystemConfig) -> None: self.reporting = config.get("reporting") - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + def run(self, command: str) -> Tuple[int, bytearray, bytearray]: """ Run command on the system. Returns a tuple: (exit_code, stdout, stderr) """ - return self.protocol.run(command, retry) - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy files to the system.""" - if isinstance(self.protocol, SupportsDeploy): - self.protocol.deploy(src, dst, retry) - - @property - def supports_deploy(self) -> bool: - """Check if protocol supports deploy operation.""" - return isinstance(self.protocol, SupportsDeploy) - - @property - def connectable(self) -> bool: - """Check if protocol supports connection.""" - return isinstance(self.protocol, SupportsConnection) - - def establish_connection(self) -> bool: - """Establish connection with the system.""" - if not isinstance(self.protocol, SupportsConnection): + cwd = self.config_location + if not isinstance(cwd, Path) or not cwd.is_dir(): raise ConfigurationException( - "System {} does not support connections".format(self.name) + f"System has invalid config location: {cwd}", ) - return self.protocol.establish_connection() + stdout = bytearray() + stderr = bytearray() - def connection_details(self) -> Tuple[str, int]: - """Return connection details.""" - if not isinstance(self.protocol, SupportsConnection): - raise ConfigurationException( - "System {} does not support connections".format(self.name) - ) - - return self.protocol.connection_details() + return run_and_wait( + command, + cwd=cwd, + terminate_on_error=True, + out=stdout, + err=stderr, + ) def __eq__(self, other: object) -> bool: """Overload operator ==.""" @@ -157,7 +69,6 @@ class System(Backend): "type": "system", "name": self.name, "description": self.description, - "data_transfer_protocol": self.protocol.protocol, "commands": self._get_command_details(), "annotations": self.annotations, } @@ -165,88 +76,66 @@ class System(Backend): return output -class StandaloneSystem(System): - """StandaloneSystem class.""" - +def get_available_systems_directory_names() -> List[str]: + """Return a list of directory names for all avialable systems.""" + return [entry.name for entry in get_backend_directories("systems")] -def get_controller( - single_instance: bool, pid_file_path: Optional[Path] = None -) -> SystemController: - """Get system controller.""" - if single_instance: - return SystemControllerSingleInstance(pid_file_path) - return SystemController() +def get_available_systems() -> List[System]: + """Return a list with all available systems.""" + available_systems = [] + for config_json in get_backend_configs("systems"): + config_entries = cast(List[SystemConfig], (load_config(config_json))) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + system = load_system(config_entry) + available_systems.append(system) + return sorted(available_systems, key=lambda system: system.name) -class ControlledSystem(System): - """ControlledSystem class.""" - def __init__(self, config: SystemConfig): - """Construct the ControlledSystem class using the dictionary passed.""" - super().__init__(config) - self.controller: Optional[SystemController] = None - - def start( - self, - commands: List[str], - single_instance: bool = True, - pid_file_path: Optional[Path] = None, - ) -> None: - """Launch the system.""" - if ( - not isinstance(self.config_location, Path) - or not self.config_location.is_dir() - ): - raise ConfigurationException( - "System {} has wrong config location".format(self.name) - ) +def get_system(system_name: str) -> System: + """Return a system instance with the same name passed as argument.""" + available_systems = get_available_systems() + for system in available_systems: + if system_name == system.name: + return system + raise ConfigurationException(f"System '{system_name}' not found.") - self.controller = get_controller(single_instance, pid_file_path) - self.controller.start(commands, self.config_location) - def is_running(self) -> bool: - """Check if system is running.""" - if not self.controller: - return False +def install_system(source_path: Path) -> None: + """Install new system.""" + try: + source = get_source(source_path) + config = cast(List[SystemConfig], source.config()) + systems_to_install = [load_system(entry) for entry in config] + except Exception as error: + raise ConfigurationException("Unable to read system definition") from error - return self.controller.is_running() + if not systems_to_install: + raise ConfigurationException("No system definition found") - def get_output(self) -> Tuple[str, str]: - """Return system output.""" - if not self.controller: - return "", "" + available_systems = get_available_systems() + already_installed = [s for s in systems_to_install if s in available_systems] + if already_installed: + names = [system.name for system in already_installed] + raise ConfigurationException( + f"Systems [{','.join(names)}] are already installed." + ) - return self.controller.get_output() + create_destination_and_install(source, get_backends_path("systems")) - def stop(self, wait: bool = False) -> None: - """Stop the system.""" - if not self.controller: - raise Exception("System has not been started") - if isinstance(self.protocol, SupportsClose): - try: - self.protocol.close() - except Exception as error: # pylint: disable=broad-except - print(error) - self.controller.stop(wait) +def remove_system(directory_name: str) -> None: + """Remove system.""" + remove_backend(directory_name, "systems") -def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: +def load_system(config: SystemConfig) -> System: """Load system based on it's execution type.""" - data_transfer = config.get("data_transfer", {}) - protocol = data_transfer.get("protocol") populate_shared_params(config) - if protocol == "ssh": - return ControlledSystem(config) - - if protocol == "local": - return StandaloneSystem(config) - - raise ConfigurationException( - "Unsupported execution type for protocol {}".format(protocol) - ) + return System(config) def populate_shared_params(config: SystemConfig) -> None: @@ -264,7 +153,7 @@ def populate_shared_params(config: SystemConfig) -> None: raise ConfigurationException("All shared parameters should have aliases") commands = config.get("commands", {}) - for cmd_name in ["build", "run"]: + for cmd_name in ["run"]: command = commands.get(cmd_name) if command is None: commands[cmd_name] = [] @@ -275,7 +164,7 @@ def populate_shared_params(config: SystemConfig) -> None: only_aliases = all(p.get("alias") for p in cmd_user_params) if not only_aliases: raise ConfigurationException( - "All parameters for command {} should have aliases".format(cmd_name) + f"All parameters for command {cmd_name} should have aliases." ) merged_by_alias = { **{p.get("alias"): p for p in shared_user_params}, diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index 1b75bb4..9006602 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -125,7 +125,7 @@ class Cell: """Return cell value.""" if self.fmt: if isinstance(self.fmt.str_fmt, str): - return "{:{fmt}}".format(self.value, fmt=self.fmt.str_fmt) + return f"{self.value:{self.fmt.str_fmt}}" if callable(self.fmt.str_fmt): return self.fmt.str_fmt(self.value) diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 0fe36e0..2252c6b 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -194,9 +194,7 @@ class TFLiteMetrics: aggregation_func = cluster_hist else: - raise NotImplementedError( - "ReportClusterMode '{}' not implemented.".format(mode) - ) + raise NotImplementedError(f"ReportClusterMode '{mode}' not implemented.") uniques = { name: aggregation_func(details) for name, details in self.filtered_details.items() @@ -217,10 +215,10 @@ class TFLiteMetrics: verbose: bool = False, ) -> None: """Print a summary of all the model information.""" - print("Model file: {}".format(self.tflite_file)) + print(f"Model file: {self.tflite_file}") print("#" * 80) print(" " * 28 + "### TFLITE SUMMARY ###") - print("File: {}".format(os.path.abspath(self.tflite_file))) + print(f"File: {os.path.abspath(self.tflite_file)}") print("Input(s):") self._print_in_outs(self.interpreter.get_input_details(), verbose) print("Output(s):") @@ -242,11 +240,11 @@ class TFLiteMetrics: ] if report_sparsity: sparsity = calculate_sparsity(weights, sparsity_accumulator) - row.append("{:.2f}".format(sparsity)) + row.append(f"{sparsity:.2f}") rows.append(row) if verbose: # Print cluster centroids - print("{} cluster centroids:".format(name)) + print(f"{name} cluster centroids:") # Types need to be ignored for this function call because # np.unique does not have type annotation while the # current context does. @@ -259,9 +257,9 @@ class TFLiteMetrics: sparsity_accumulator.total_weights ) if report_sparsity: - summary_row[header.index("Sparsity")] = "{:.2f}".format( - sparsity_accumulator.sparsity() - ) + summary_row[ + header.index("Sparsity") + ] = f"{sparsity_accumulator.sparsity():.2f}" rows.append(summary_row) # Report detailed cluster info if report_cluster_mode is not None: @@ -272,7 +270,7 @@ class TFLiteMetrics: def _print_cluster_details( self, report_cluster_mode: ReportClusterMode, max_num_clusters: int ) -> None: - print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value)) + print(f"{report_cluster_mode.name}:\n{report_cluster_mode.value}") num_clusters = self.num_unique_weights(report_cluster_mode) if ( report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM @@ -283,11 +281,9 @@ class TFLiteMetrics: # histogram for unclustered layers. for name, value in num_clusters.items(): if len(value) > max_num_clusters: - num_clusters[name] = "More than {} unique values.".format( - max_num_clusters - ) + num_clusters[name] = f"More than {max_num_clusters} unique values." for name, nums in num_clusters.items(): - print("- {}: {}".format(self._prettify_name(name), nums)) + print(f"- {self._prettify_name(name)}: {nums}") @staticmethod def _print_in_outs(ios: List[dict], verbose: bool = False) -> None: @@ -296,9 +292,6 @@ class TFLiteMetrics: pprint(item) else: print( - "- {} ({}): {}".format( - item["name"], - np.dtype(item["dtype"]).name, - item["shape"], - ) + f"- {item['name']} ({np.dtype(item['dtype']).name}): " + f"{item['shape']}" ) diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/backend_configs/systems/SYSTEMS.txt index 3861769..3861769 100644 --- a/src/mlia/resources/aiet/systems/SYSTEMS.txt +++ b/src/mlia/resources/backend_configs/systems/SYSTEMS.txt diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json b/src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json index 3ffa548..5c44ebc 100644 --- a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json +++ b/src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json @@ -7,10 +7,6 @@ "sim_type": "FM", "variant": "Cortex-M55+Ethos-U55" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" @@ -47,10 +43,6 @@ "sim_type": "FM", "variant": "Cortex-M55+Ethos-U65" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license +++ b/src/mlia/resources/backend_configs/systems/corstone-300-vht/backend-config.json.license diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json b/src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json index 6d6785d..41d2fd0 100644 --- a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json +++ b/src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json @@ -7,10 +7,6 @@ "sim_type": "FM", "variant": "Cortex-M55+Ethos-U55" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "FVP_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" @@ -47,10 +43,6 @@ "sim_type": "FM", "variant": "Cortex-M55+Ethos-U65" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "FVP_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license +++ b/src/mlia/resources/backend_configs/systems/corstone-300/backend-config.json.license diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json b/src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json index dbc2622..3ea9a6a 100644 --- a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json +++ b/src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json @@ -7,10 +7,6 @@ "sim_type": "FM", "variant": "Cortex-M85+Ethos-U55" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "/opt/VHT/VHT_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license b/src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license +++ b/src/mlia/resources/backend_configs/systems/corstone-310-vht/backend-config.json.license diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json b/src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json index 7aa3b0a..d043a2d 100644 --- a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json +++ b/src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json @@ -7,10 +7,6 @@ "sim_type": "FM", "variant": "Cortex-M85+Ethos-U55" }, - "data_transfer": { - "protocol": "local" - }, - "lock": true, "commands": { "run": [ "FVP_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat" diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license +++ b/src/mlia/resources/backend_configs/systems/corstone-310/backend-config.json.license diff --git a/src/mlia/resources/backends/applications/.gitignore b/src/mlia/resources/backends/applications/.gitignore deleted file mode 100644 index 0226166..0000000 --- a/src/mlia/resources/backends/applications/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/backends/applications/APPLICATIONS.txt index a702e19..ca1003b 100644 --- a/src/mlia/resources/aiet/applications/APPLICATIONS.txt +++ b/src/mlia/resources/backends/applications/APPLICATIONS.txt @@ -4,4 +4,4 @@ SPDX-License-Identifier: Apache-2.0 This directory contains the application packages for the Generic Inference Runner. -Each package should contain its own aiet-config.json file. +Each package should contain its own backend-config.json file. diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json index 757ccd1..7ee5e00 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json @@ -10,7 +10,6 @@ "name": "Corstone-300: Cortex-M55+Ethos-U65" } ], - "lock": true, "variables": { "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" } diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf Binary files differindex 4c50e1f..4c50e1f 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license index 8896f92..8896f92 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json index cb7e113..51ff429 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json @@ -7,7 +7,6 @@ "name": "Corstone-300: Cortex-M55+Ethos-U55" } ], - "lock": true, "variables": { "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" } diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf Binary files differindex 850e2eb..850e2eb 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license index 8896f92..8896f92 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json index d524f64..b59c85e 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json @@ -7,7 +7,6 @@ "name": "Corstone-300: Cortex-M55+Ethos-U65" } ], - "lock": true, "variables": { "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" } diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/backend-config.json.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf Binary files differindex f881bb8..f881bb8 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license index 8896f92..8896f92 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json index 2cbab70..69c5e60 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json @@ -7,7 +7,6 @@ "name": "Corstone-310: Cortex-M85+Ethos-U55" } ], - "lock": true, "variables": { "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" } diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/backend-config.json.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf Binary files differindex 846ee33..846ee33 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license index 8896f92..8896f92 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json index 01bec74..fbe4a16 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json @@ -7,7 +7,6 @@ "name": "Corstone-310: Cortex-M85+Ethos-U55" } ], - "lock": true, "variables": { "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf" } diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/backend-config.json.license diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf Binary files differindex e3eab97..e3eab97 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license index 8896f92..8896f92 100644 --- a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license +++ b/src/mlia/resources/backends/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py index a92f81c..6a3c1c8 100644 --- a/src/mlia/tools/metadata/corstone.py +++ b/src/mlia/tools/metadata/corstone.py @@ -13,7 +13,6 @@ from typing import List from typing import Optional import mlia.backend.manager as backend_manager -from mlia.backend.fs import get_backend_resources from mlia.tools.metadata.common import DownloadAndInstall from mlia.tools.metadata.common import Installation from mlia.tools.metadata.common import InstallationType @@ -24,7 +23,7 @@ from mlia.utils.filesystem import all_paths_valid from mlia.utils.filesystem import copy_all from mlia.utils.filesystem import get_mlia_resources from mlia.utils.filesystem import temp_directory -from mlia.utils.proc import working_directory +from mlia.utils.filesystem import working_directory logger = logging.getLogger(__name__) @@ -76,7 +75,7 @@ class BackendMetadata: """Return list of expected resources.""" resources = [self.system_config, *self.apps_resources] - return (get_backend_resources() / resource for resource in resources) + return (get_mlia_resources() / resource for resource in resources) @property def supported_platform(self) -> bool: @@ -314,12 +313,8 @@ def get_corstone_300_installation() -> Installation: metadata=BackendMetadata( name="Corstone-300", description="Corstone-300 FVP", - system_config="aiet/systems/corstone-300/aiet-config.json", - apps_resources=[ - "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA", - "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA", - "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA", - ], + system_config="backend_configs/systems/corstone-300/backend-config.json", + apps_resources=[], fvp_dir_name="corstone_300", download_artifact=DownloadArtifact( name="Corstone-300 FVP", @@ -346,7 +341,9 @@ def get_corstone_300_installation() -> Installation: "VHT_Corstone_SSE-300_Ethos-U65", ], copy_source=False, - system_config="aiet/systems/corstone-300-vht/aiet-config.json", + system_config=( + "backends_configs/systems/corstone-300-vht/backend-config.json" + ), ), ), backend_installer=Corstone300Installer(), @@ -363,11 +360,8 @@ def get_corstone_310_installation() -> Installation: metadata=BackendMetadata( name="Corstone-310", description="Corstone-310 FVP", - system_config="aiet/systems/corstone-310/aiet-config.json", - apps_resources=[ - "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA", - "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA", - ], + system_config="backend_configs/systems/corstone-310/backend-config.json", + apps_resources=[], fvp_dir_name="corstone_310", download_artifact=None, supported_platforms=["Linux"], @@ -386,7 +380,9 @@ def get_corstone_310_installation() -> Installation: "VHT_Corstone_SSE-310", ], copy_source=False, - system_config="aiet/systems/corstone-310-vht/aiet-config.json", + system_config=( + "backend_configs/systems/corstone-310-vht/backend-config.json" + ), ), ), backend_installer=None, diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index 73a88d9..7975905 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -122,3 +122,21 @@ def copy_all(*paths: Path, dest: Path) -> None: if path.is_dir(): shutil.copytree(path, dest, dirs_exist_ok=True) + + +@contextmanager +def working_directory( + working_dir: Path, create_dir: bool = False +) -> Generator[Path, None, None]: + """Temporary change working directory.""" + current_working_dir = Path.cwd() + + if create_dir: + working_dir.mkdir() + + os.chdir(working_dir) + + try: + yield working_dir + finally: + os.chdir(current_working_dir) diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py deleted file mode 100644 index 18a4305..0000000 --- a/src/mlia/utils/proc.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Utils related to process management.""" -import os -import signal -import subprocess -import time -from abc import ABC -from abc import abstractmethod -from contextlib import contextmanager -from pathlib import Path -from typing import Any -from typing import Generator -from typing import Iterable -from typing import List -from typing import Optional -from typing import Tuple - - -class OutputConsumer(ABC): - """Base class for the output consumers.""" - - @abstractmethod - def feed(self, line: str) -> None: - """Feed new line to the consumer.""" - - -class RunningCommand: - """Running command.""" - - def __init__(self, process: subprocess.Popen) -> None: - """Init running command instance.""" - self.process = process - self.output_consumers: List[OutputConsumer] = [] - - def is_alive(self) -> bool: - """Return true if process is still alive.""" - return self.process.poll() is None - - def exit_code(self) -> Optional[int]: - """Return process's return code.""" - return self.process.poll() - - def stdout(self) -> Iterable[str]: - """Return std output of the process.""" - assert self.process.stdout is not None - - for line in self.process.stdout: - yield line - - def kill(self) -> None: - """Kill the process.""" - self.process.kill() - - def send_signal(self, signal_num: int) -> None: - """Send signal to the process.""" - self.process.send_signal(signal_num) - - def consume_output(self) -> None: - """Pass program's output to the consumers.""" - if self.process is None or not self.output_consumers: - return - - for line in self.stdout(): - for consumer in self.output_consumers: - consumer.feed(line) - - def stop( - self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5 - ) -> None: - """Stop execution.""" - try: - if not self.is_alive(): - return - - self.process.send_signal(signal.SIGINT) - self.consume_output() - - if not wait: - return - - for _ in range(num_of_attempts): - time.sleep(interval) - if not self.is_alive(): - break - else: - raise Exception("Unable to stop running command") - finally: - self._close_fd() - - def _close_fd(self) -> None: - """Close file descriptors.""" - - def close(file_descriptor: Any) -> None: - """Check and close file.""" - if file_descriptor is not None and hasattr(file_descriptor, "close"): - file_descriptor.close() - - close(self.process.stdout) - close(self.process.stderr) - - def wait(self, redirect_output: bool = False) -> None: - """Redirect process output to stdout and wait for completion.""" - if redirect_output: - for line in self.stdout(): - print(line, end="") - - self.process.wait() - - -class CommandExecutor: - """Command executor.""" - - @staticmethod - def execute(command: List[str]) -> Tuple[int, bytes, bytes]: - """Execute the command.""" - result = subprocess.run( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True - ) - - return (result.returncode, result.stdout, result.stderr) - - @staticmethod - def submit(command: List[str]) -> RunningCommand: - """Submit command for the execution.""" - process = subprocess.Popen( # pylint: disable=consider-using-with - command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, # redirect command stderr to stdout - universal_newlines=True, - bufsize=1, - ) - - return RunningCommand(process) - - -@contextmanager -def working_directory( - working_dir: Path, create_dir: bool = False -) -> Generator[Path, None, None]: - """Temporary change working directory.""" - current_working_dir = Path.cwd() - - if create_dir: - working_dir.mkdir() - - os.chdir(working_dir) - - try: - yield working_dir - finally: - os.chdir(current_working_dir) diff --git a/tests/conftest.py b/tests/conftest.py index 5c6156c..4d12033 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 """Pytest conf module.""" import shutil +import tarfile from pathlib import Path +from typing import Any from typing import Generator import pytest import tensorflow as tf +from mlia.core.context import ExecutionContext from mlia.devices.ethosu.config import EthosUConfiguration from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model @@ -15,6 +18,106 @@ from mlia.nn.tensorflow.utils import save_tflite_model from mlia.tools.vela_wrapper import optimize_model +@pytest.fixture(scope="session", name="test_resources_path") +def fixture_test_resources_path() -> Path: + """Return test resources path.""" + return Path(__file__).parent / "test_resources" + + +@pytest.fixture(name="dummy_context") +def fixture_dummy_context(tmpdir: str) -> ExecutionContext: + """Return dummy context fixture.""" + return ExecutionContext(working_dir=tmpdir) + + +@pytest.fixture(scope="session") +def test_systems_path(test_resources_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_resources_path / "backends" / "systems" + + +@pytest.fixture(scope="session") +def test_applications_path(test_resources_path: Path) -> Path: + """Return test applications path in a pytest fixture.""" + return test_resources_path / "backends" / "applications" + + +@pytest.fixture(scope="session") +def non_optimised_input_model_file(test_tflite_model: Path) -> Path: + """Provide the path to a quantized dummy model file.""" + return test_tflite_model + + +@pytest.fixture(scope="session") +def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: + """Provide path to Vela-optimised dummy model file.""" + return test_tflite_vela_model + + +@pytest.fixture(scope="session") +def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: + """Provide the path to an invalid dummy model file.""" + return test_tflite_invalid_model + + +@pytest.fixture(autouse=True) +def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: + """Force using test resources as middleware's repository.""" + + def get_test_resources() -> Path: + """Return path to the test resources.""" + return test_resources_path / "backends" + + monkeypatch.setattr("mlia.backend.fs.get_backend_resources", get_test_resources) + yield + + +def create_archive( + archive_name: str, source: Path, destination: Path, with_root_folder: bool = False +) -> None: + """Create archive from directory source.""" + with tarfile.open(destination / archive_name, mode="w:gz") as tar: + for item in source.iterdir(): + item_name = item.name + if with_root_folder: + item_name = f"{source.name}/{item_name}" + tar.add(item, item_name) + + +def process_directory(source: Path, destination: Path) -> None: + """Process resource directory.""" + destination.mkdir() + + for item in source.iterdir(): + if item.is_dir(): + create_archive(f"{item.name}.tar.gz", item, destination) + create_archive(f"{item.name}_dir.tar.gz", item, destination, True) + + +@pytest.fixture(scope="session", autouse=True) +def add_archives( + test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory +) -> Any: + """Generate archives of the test resources.""" + tmp_path = tmp_path_factory.mktemp("archives") + + archives_path = tmp_path / "archives" + archives_path.mkdir() + + if (archives_path_link := test_resources_path / "archives").is_symlink(): + archives_path_link.unlink() + + archives_path_link.symlink_to(archives_path, target_is_directory=True) + + for item in ["applications", "systems"]: + process_directory(test_resources_path / "backends" / item, archives_path / item) + + yield + + archives_path_link.unlink() + shutil.rmtree(tmp_path) + + def get_test_keras_model() -> tf.keras.Model: """Return test Keras model.""" model = tf.keras.Sequential( diff --git a/tests/mlia/__init__.py b/tests/mlia/__init__.py deleted file mode 100644 index 0687f14..0000000 --- a/tests/mlia/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""MLIA tests module.""" diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py deleted file mode 100644 index 0b4b2aa..0000000 --- a/tests/mlia/conftest.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Pytest conf module.""" -import shutil -import tarfile -from pathlib import Path -from typing import Any - -import pytest - -from mlia.core.context import ExecutionContext - - -@pytest.fixture(scope="session", name="test_resources_path") -def fixture_test_resources_path() -> Path: - """Return test resources path.""" - return Path(__file__).parent / "test_resources" - - -@pytest.fixture(name="dummy_context") -def fixture_dummy_context(tmpdir: str) -> ExecutionContext: - """Return dummy context fixture.""" - return ExecutionContext(working_dir=tmpdir) - - -@pytest.fixture(scope="session") -def test_systems_path(test_resources_path: Path) -> Path: - """Return test systems path in a pytest fixture.""" - return test_resources_path / "backends" / "systems" - - -@pytest.fixture(scope="session") -def test_applications_path(test_resources_path: Path) -> Path: - """Return test applications path in a pytest fixture.""" - return test_resources_path / "backends" / "applications" - - -@pytest.fixture(scope="session") -def non_optimised_input_model_file(test_tflite_model: Path) -> Path: - """Provide the path to a quantized dummy model file.""" - return test_tflite_model - - -@pytest.fixture(scope="session") -def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: - """Provide path to Vela-optimised dummy model file.""" - return test_tflite_vela_model - - -@pytest.fixture(scope="session") -def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: - """Provide the path to an invalid dummy model file.""" - return test_tflite_invalid_model - - -@pytest.fixture(autouse=True) -def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: - """Force using test resources as middleware's repository.""" - - def get_test_resources() -> Path: - """Return path to the test resources.""" - return test_resources_path / "backends" - - monkeypatch.setattr("mlia.backend.fs.get_backend_resources", get_test_resources) - yield - - -def create_archive( - archive_name: str, source: Path, destination: Path, with_root_folder: bool = False -) -> None: - """Create archive from directory source.""" - with tarfile.open(destination / archive_name, mode="w:gz") as tar: - for item in source.iterdir(): - item_name = item.name - if with_root_folder: - item_name = f"{source.name}/{item_name}" - tar.add(item, item_name) - - -def process_directory(source: Path, destination: Path) -> None: - """Process resource directory.""" - destination.mkdir() - - for item in source.iterdir(): - if item.is_dir(): - create_archive(f"{item.name}.tar.gz", item, destination) - create_archive(f"{item.name}_dir.tar.gz", item, destination, True) - - -@pytest.fixture(scope="session", autouse=True) -def add_archives( - test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory -) -> Any: - """Generate archives of the test resources.""" - tmp_path = tmp_path_factory.mktemp("archives") - - archives_path = tmp_path / "archives" - archives_path.mkdir() - - if (archives_path_link := test_resources_path / "archives").is_symlink(): - archives_path_link.unlink() - - archives_path_link.symlink_to(archives_path, target_is_directory=True) - - for item in ["applications", "systems"]: - process_directory(test_resources_path / "backends" / item, archives_path / item) - - yield - - archives_path_link.unlink() - shutil.rmtree(tmp_path) diff --git a/tests/mlia/test_backend_controller.py b/tests/mlia/test_backend_controller.py deleted file mode 100644 index a047adf..0000000 --- a/tests/mlia/test_backend_controller.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for system controller.""" -import csv -import os -import time -from pathlib import Path -from typing import Any - -import psutil -import pytest - -from mlia.backend.common import ConfigurationException -from mlia.backend.controller import SystemController -from mlia.backend.controller import SystemControllerSingleInstance -from mlia.backend.proc import ShellCommand - - -def get_system_controller(**kwargs: Any) -> SystemController: - """Get service controller.""" - single_instance = kwargs.get("single_instance", False) - if single_instance: - pid_file_path = kwargs.get("pid_file_path") - return SystemControllerSingleInstance(pid_file_path) - - return SystemController() - - -def test_service_controller() -> None: - """Test service controller functionality.""" - service_controller = get_system_controller() - - assert service_controller.get_output() == ("", "") - with pytest.raises(ConfigurationException, match="Wrong working directory"): - service_controller.start(["sleep 100"], Path("unknown")) - - service_controller.start(["sleep 100"], Path.cwd()) - assert service_controller.is_running() - - service_controller.stop(True) - assert not service_controller.is_running() - assert service_controller.get_output() == ("", "") - - service_controller.stop() - - with pytest.raises( - ConfigurationException, match="System should have only one command to run" - ): - service_controller.start(["sleep 100", "sleep 101"], Path.cwd()) - - with pytest.raises(ConfigurationException, match="No startup command provided"): - service_controller.start([""], Path.cwd()) - - -def test_service_controller_bad_configuration() -> None: - """Test service controller functionality for bad configuration.""" - with pytest.raises(Exception, match="No pid file path presented"): - service_controller = get_system_controller( - single_instance=True, pid_file_path=None - ) - service_controller.start(["sleep 100"], Path.cwd()) - - -def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None: - """Test that controller writes process info correctly.""" - pid_file = Path(tmpdir) / "test.pid" - - service_controller = get_system_controller( - single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" - ) - - service_controller.start(["sleep 100"], Path.cwd()) - assert service_controller.is_running() - assert pid_file.is_file() - - with open(pid_file, "r", encoding="utf-8") as file: - csv_reader = csv.reader(file) - rows = list(csv_reader) - assert len(rows) == 1 - - name, *_ = rows[0] - assert name == "sleep" - - service_controller.stop() - assert pid_file.exists() - - -def test_service_controller_does_not_write_process_info_if_process_finishes( - tmpdir: Any, -) -> None: - """Test that controller does not write process info if process already finished.""" - pid_file = Path(tmpdir) / "test.pid" - service_controller = get_system_controller( - single_instance=True, pid_file_path=pid_file - ) - service_controller.is_running = lambda: False # type: ignore - service_controller.start(["echo hello"], Path.cwd()) - - assert not pid_file.exists() - - -def test_service_controller_searches_for_previous_instances_correctly( - tmpdir: Any, -) -> None: - """Test that controller searches for previous instances correctly.""" - pid_file = Path(tmpdir) / "test.pid" - command = ShellCommand().run("sleep", "100") - assert command.is_alive() - - pid = command.process.pid - process = psutil.Process(pid) - with open(pid_file, "w", encoding="utf-8") as file: - csv_writer = csv.writer(file) - csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid())) - csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid)) - csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777)) - - service_controller = get_system_controller( - single_instance=True, pid_file_path=pid_file - ) - service_controller.start(["sleep 100"], Path.cwd()) - # controller should stop this process as it is currently running and - # mentioned in pid file - assert not command.is_alive() - - service_controller.stop() - - -@pytest.mark.parametrize( - "executable", ["test_backend_run_script.sh", "test_backend_run"] -) -def test_service_controller_run_shell_script( - executable: str, test_resources_path: Path -) -> None: - """Test controller's ability to run shell scripts.""" - script_path = test_resources_path / "scripts" - - service_controller = get_system_controller() - - service_controller.start([executable], script_path) - - assert service_controller.is_running() - # give time for the command to produce output - time.sleep(2) - service_controller.stop(wait=True) - assert not service_controller.is_running() - stdout, stderr = service_controller.get_output() - assert stdout == "Hello from script\n" - assert stderr == "Oops!\n" - - -def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None: - """Test that nothing happened if controller is not started.""" - service_controller = get_system_controller( - single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" - ) - - assert not service_controller.is_running() - service_controller.stop() - assert not service_controller.is_running() diff --git a/tests/mlia/test_backend_execution.py b/tests/mlia/test_backend_execution.py deleted file mode 100644 index 9395352..0000000 --- a/tests/mlia/test_backend_execution.py +++ /dev/null @@ -1,518 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Test backend execution module.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import Dict -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from sh import CommandNotFound - -from mlia.backend.application import Application -from mlia.backend.application import get_application -from mlia.backend.common import DataPaths -from mlia.backend.common import UserParamConfig -from mlia.backend.config import ApplicationConfig -from mlia.backend.config import LocalProtocolConfig -from mlia.backend.config import SystemConfig -from mlia.backend.execution import deploy_data -from mlia.backend.execution import execute_commands_locally -from mlia.backend.execution import ExecutionContext -from mlia.backend.execution import get_application_and_system -from mlia.backend.execution import get_application_by_name_and_system -from mlia.backend.execution import get_file_lock_path -from mlia.backend.execution import ParamResolver -from mlia.backend.execution import Reporter -from mlia.backend.execution import wait -from mlia.backend.output_parser import Base64OutputParser -from mlia.backend.output_parser import OutputParser -from mlia.backend.output_parser import RegexOutputParser -from mlia.backend.proc import CommandFailedException -from mlia.backend.system import get_system -from mlia.backend.system import load_system - - -def test_context_param_resolver(tmpdir: Any) -> None: - """Test parameter resolving.""" - system_config_location = Path(tmpdir) / "system" - system_config_location.mkdir() - - application_config_location = Path(tmpdir) / "application" - application_config_location.mkdir() - - ctx = ExecutionContext( - app=Application( - ApplicationConfig( - name="test_application", - description="Test application", - config_location=application_config_location, - build_dir="build-{application.name}-{system.name}", - commands={ - "run": [ - "run_command1 {user_params:0}", - "run_command2 {user_params:1}", - ] - }, - variables={"var_1": "value for var_1"}, - user_params={ - "run": [ - UserParamConfig( - name="--param1", - description="Param 1", - default_value="123", - alias="param_1", - ), - UserParamConfig( - name="--param2", description="Param 2", default_value="456" - ), - UserParamConfig( - name="--param3", description="Param 3", alias="param_3" - ), - UserParamConfig( - name="--param4=", - description="Param 4", - default_value="456", - alias="param_4", - ), - UserParamConfig( - description="Param 5", - default_value="789", - alias="param_5", - ), - ] - }, - ) - ), - app_params=["--param2=789"], - system=load_system( - SystemConfig( - name="test_system", - description="Test system", - config_location=system_config_location, - build_dir="build", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={ - "build": ["build_command1 {user_params:0}"], - "run": ["run_command {application.commands.run:1}"], - }, - variables={"var_1": "value for var_1"}, - user_params={ - "build": [ - UserParamConfig( - name="--param1", description="Param 1", default_value="aaa" - ), - UserParamConfig(name="--param2", description="Param 2"), - ] - }, - ) - ), - system_params=["--param1=bbb"], - custom_deploy_data=[], - ) - - param_resolver = ParamResolver(ctx) - expected_values = { - "application.name": "test_application", - "application.description": "Test application", - "application.config_dir": str(application_config_location), - "application.build_dir": "{}/build-test_application-test_system".format( - application_config_location - ), - "application.commands.run:0": "run_command1 --param1 123", - "application.commands.run.params:0": "123", - "application.commands.run.params:param_1": "123", - "application.commands.run:1": "run_command2 --param2 789", - "application.commands.run.params:1": "789", - "application.variables:var_1": "value for var_1", - "system.name": "test_system", - "system.description": "Test system", - "system.config_dir": str(system_config_location), - "system.commands.build:0": "build_command1 --param1 bbb", - "system.commands.run:0": "run_command run_command2 --param2 789", - "system.commands.build.params:0": "bbb", - "system.variables:var_1": "value for var_1", - } - - for param, value in expected_values.items(): - assert param_resolver(param) == value - - assert ctx.build_dir() == Path( - "{}/build-test_application-test_system".format(application_config_location) - ) - - expected_errors = { - "application.variables:var_2": pytest.raises( - Exception, match="Unknown variable var_2" - ), - "application.commands.clean:0": pytest.raises( - Exception, match="Command clean not found" - ), - "application.commands.run:2": pytest.raises( - Exception, match="Invalid index 2 for command run" - ), - "application.commands.run.params:5": pytest.raises( - Exception, match="Invalid parameter index 5 for command run" - ), - "application.commands.run.params:param_2": pytest.raises( - Exception, - match="No value for parameter with index or alias param_2 of command run", - ), - "UNKNOWN": pytest.raises( - Exception, match="Unable to resolve parameter UNKNOWN" - ), - "system.commands.build.params:1": pytest.raises( - Exception, - match="No value for parameter with index or alias 1 of command build", - ), - "system.commands.build:A": pytest.raises( - Exception, match="Bad command index A" - ), - "system.variables:var_2": pytest.raises( - Exception, match="Unknown variable var_2" - ), - } - for param, error in expected_errors.items(): - with error: - param_resolver(param) - - resolved_params = ctx.app.resolved_parameters("run", []) - expected_user_params = { - "user_params:0": "--param1 123", - "user_params:param_1": "--param1 123", - "user_params:2": "--param3", - "user_params:param_3": "--param3", - "user_params:3": "--param4=456", - "user_params:param_4": "--param4=456", - "user_params:param_5": "789", - } - for param, expected_value in expected_user_params.items(): - assert param_resolver(param, "run", resolved_params) == expected_value - - with pytest.raises( - Exception, match="Invalid index 5 for user params of command run" - ): - param_resolver("user_params:5", "run", resolved_params) - - with pytest.raises( - Exception, match="No user parameter for command 'run' with alias 'param_2'." - ): - param_resolver("user_params:param_2", "run", resolved_params) - - with pytest.raises(Exception, match="Unable to resolve user params"): - param_resolver("user_params:0", "", resolved_params) - - bad_ctx = ExecutionContext( - app=Application( - ApplicationConfig( - name="test_application", - config_location=application_config_location, - build_dir="build-{user_params:0}", - ) - ), - app_params=["--param2=789"], - system=load_system( - SystemConfig( - name="test_system", - description="Test system", - config_location=system_config_location, - build_dir="build-{system.commands.run.params:123}", - data_transfer=LocalProtocolConfig(protocol="local"), - ) - ), - system_params=["--param1=bbb"], - custom_deploy_data=[], - ) - param_resolver = ParamResolver(bad_ctx) - with pytest.raises(Exception, match="Unable to resolve user params"): - bad_ctx.build_dir() - - -# pylint: disable=too-many-arguments -@pytest.mark.parametrize( - "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path", - ( - ( - "test_application", - True, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application_test_system.lock"), - ), - ( - "$$test_application$!:", - True, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application_test_system.lock"), - ), - ( - "test_application", - True, - True, - Path("unknown"), - pytest.raises( - Exception, match="Invalid directory unknown for lock files provided" - ), - None, - ), - ( - "test_application", - False, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_system.lock"), - ), - ( - "test_application", - True, - False, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application.lock"), - ), - ( - "test_application", - False, - False, - Path("/tmp"), - pytest.raises(Exception, match="No filename for lock provided"), - None, - ), - ), -) -def test_get_file_lock_path( - application_name: str, - soft_lock: bool, - sys_lock: bool, - lock_dir: Path, - expected_error: Any, - expected_path: Path, -) -> None: - """Test get_file_lock_path function.""" - with expected_error: - ctx = ExecutionContext( - app=Application(ApplicationConfig(name=application_name, lock=soft_lock)), - app_params=[], - system=load_system( - SystemConfig( - name="test_system", - lock=sys_lock, - data_transfer=LocalProtocolConfig(protocol="local"), - ) - ), - system_params=[], - custom_deploy_data=[], - ) - path = get_file_lock_path(ctx, lock_dir) - assert path == expected_path - - -def test_get_application_by_name_and_system(monkeypatch: Any) -> None: - """Test exceptional case for get_application_by_name_and_system.""" - monkeypatch.setattr( - "mlia.backend.execution.get_application", - MagicMock(return_value=[MagicMock(), MagicMock()]), - ) - - with pytest.raises( - ValueError, - match="Error during getting application test_application for the " - "system test_system", - ): - get_application_by_name_and_system("test_application", "test_system") - - -def test_get_application_and_system(monkeypatch: Any) -> None: - """Test exceptional case for get_application_and_system.""" - monkeypatch.setattr( - "mlia.backend.execution.get_system", MagicMock(return_value=None) - ) - - with pytest.raises(ValueError, match="System test_system is not found"): - get_application_and_system("test_application", "test_system") - - -def test_wait_function(monkeypatch: Any) -> None: - """Test wait function.""" - sleep_mock = MagicMock() - monkeypatch.setattr("time.sleep", sleep_mock) - wait(0.1) - sleep_mock.assert_called_once() - - -def test_deployment_execution_context() -> None: - """Test property 'is_deploy_needed' of the ExecutionContext.""" - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=get_system("System 1"), - system_params=[], - ) - assert not ctx.is_deploy_needed - deploy_data(ctx) # should be a NOP - - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=get_system("System 1"), - system_params=[], - custom_deploy_data=[DataPaths(Path("README.md"), ".")], - ) - assert ctx.is_deploy_needed - - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=None, - system_params=[], - ) - assert not ctx.is_deploy_needed - with pytest.raises(AssertionError): - deploy_data(ctx) - - -def test_reporter_execution_context(tmp_path: Path) -> None: - """Test ExecutionContext creates a reporter when a report file is provided.""" - # Configure regex parser for the system manually - sys = get_system("System 1") - assert sys is not None - sys.reporting = { - "regex": { - "simulated_time": {"pattern": "Simulated time.*: (.*)s", "type": "float"} - } - } - report_file_path = tmp_path / "test_report.json" - - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=sys, - system_params=[], - report_file=report_file_path, - ) - assert isinstance(ctx.reporter, Reporter) - assert len(ctx.reporter.parsers) == 2 - assert any(isinstance(parser, RegexOutputParser) for parser in ctx.reporter.parsers) - assert any( - isinstance(parser, Base64OutputParser) for parser in ctx.reporter.parsers - ) - - -class TestExecuteCommandsLocally: - """Test execute_commands_locally() function.""" - - @pytest.mark.parametrize( - "first_command, exception, expected_output", - ( - ( - "echo 'hello'", - None, - "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n", - ), - ( - "non-existent-command", - CommandNotFound, - "Running: non-existent-command\n", - ), - ("false", CommandFailedException, "Running: false\n"), - ), - ids=( - "runs_multiple_commands", - "stops_executing_on_non_existent_command", - "stops_executing_when_command_exits_with_error_code", - ), - ) - def test_execution( - self, - first_command: str, - exception: Any, - expected_output: str, - test_resources_path: Path, - capsys: Any, - ) -> None: - """Test expected behaviour of the function.""" - commands = [first_command, "echo 'goodbye'"] - cwd = test_resources_path - if exception is None: - execute_commands_locally(commands, cwd) - else: - with pytest.raises(exception): - execute_commands_locally(commands, cwd) - - captured = capsys.readouterr() - assert captured.out == expected_output - - def test_stops_executing_on_exception( - self, monkeypatch: Any, test_resources_path: Path - ) -> None: - """Ensure commands following an error-exit-code command don't run.""" - # Mock execute_command() function - execute_command_mock = mock.MagicMock() - monkeypatch.setattr("mlia.backend.proc.execute_command", execute_command_mock) - - # Mock Command object and assign as return value to execute_command() - cmd_mock = mock.MagicMock() - execute_command_mock.return_value = cmd_mock - - # Mock the terminate_command (speed up test) - terminate_command_mock = mock.MagicMock() - monkeypatch.setattr( - "mlia.backend.proc.terminate_command", terminate_command_mock - ) - - # Mock a thrown Exception and assign to Command().exit_code - exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception.")) - type(cmd_mock).exit_code = exit_code_mock - - with pytest.raises(Exception, match="Exception."): - execute_commands_locally( - ["command_1", "command_2"], cwd=test_resources_path - ) - - # Assert only "command_1" was executed - assert execute_command_mock.call_count == 1 - - -def test_reporter(tmpdir: Any) -> None: - """Test class 'Reporter'.""" - ctx = ExecutionContext( - app=get_application("application_4")[0], - app_params=["--app=TestApp"], - system=get_system("System 4"), - system_params=[], - ) - assert ctx.system is not None - - class MockParser(OutputParser): - """Mock implementation of an output parser.""" - - def __init__(self, metrics: Dict[str, Any]) -> None: - """Set up the MockParser.""" - super().__init__(name="test") - self.metrics = metrics - - def __call__(self, output: bytearray) -> Dict[str, Any]: - """Return mock metrics (ignoring the given output).""" - return self.metrics - - metrics = {"Metric": 123, "AnotherMetric": 456} - reporter = Reporter( - parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()], - ) - reporter.parse(bytearray()) - report = reporter.report(ctx) - assert report["system"]["name"] == ctx.system.name - assert report["system"]["params"] == {} - assert report["application"]["name"] == ctx.app.name - assert report["application"]["params"] == {"--app": "TestApp"} - assert report["test"]["metrics"] == metrics - report_file = Path(tmpdir) / "report.json" - reporter.save(report, report_file) - assert report_file.is_file() diff --git a/tests/mlia/test_backend_output_parser.py b/tests/mlia/test_backend_output_parser.py deleted file mode 100644 index d86aac8..0000000 --- a/tests/mlia/test_backend_output_parser.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the output parsing.""" -import base64 -import json -from typing import Any -from typing import Dict - -import pytest - -from mlia.backend.output_parser import Base64OutputParser -from mlia.backend.output_parser import OutputParser -from mlia.backend.output_parser import RegexOutputParser - - -OUTPUT_MATCH_ALL = bytearray( - """ -String1: My awesome string! -String2: STRINGS_ARE_GREAT!!! -Int: 12 -Float: 3.14 -""", - encoding="utf-8", -) - -OUTPUT_NO_MATCH = bytearray( - """ -This contains no matches... -Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| -""", - encoding="utf-8", -) - -OUTPUT_PARTIAL_MATCH = bytearray( - "String1: My awesome string!", - encoding="utf-8", -) - -REGEX_CONFIG = { - "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, - "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, - "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, - "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, -} - -EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} - -EXPECTED_METRICS_ALL = { - "FirstString": "My awesome string!", - "SecondString": "STRINGS_ARE_GREAT", - "IntegerValue": 12, - "FloatValue": 3.14, -} - -EXPECTED_METRICS_PARTIAL = { - "FirstString": "My awesome string!", -} - - -class TestRegexOutputParser: - """Collect tests for the RegexOutputParser.""" - - @staticmethod - @pytest.mark.parametrize( - ["output", "config", "expected_metrics"], - [ - (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL), - (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), - (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), - ( - OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH, - REGEX_CONFIG, - EXPECTED_METRICS_ALL, - ), - (OUTPUT_NO_MATCH, REGEX_CONFIG, {}), - (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}), - (bytearray(), EMPTY_REGEX_CONFIG, {}), - (bytearray(), REGEX_CONFIG, {}), - ], - ) - def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None: - """ - Make sure the RegexOutputParser yields valid results. - - I.e. return an empty dict if either the input or the config is empty and - return the parsed metrics otherwise. - """ - parser = RegexOutputParser(name="Test", regex_config=config) - assert parser.name == "Test" - assert isinstance(parser, OutputParser) - res = parser(output) - assert res == expected_metrics - - @staticmethod - def test_unsupported_type() -> None: - """An unsupported type in the regex_config must raise an exception.""" - config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}} - with pytest.raises(TypeError): - RegexOutputParser(name="Test", regex_config=config) - - @staticmethod - @pytest.mark.parametrize( - "config", - ( - {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}}, - {"NoGroups": {"pattern": r"\W", "type": "str"}}, - ), - ) - def test_invalid_pattern(config: Dict) -> None: - """Exactly one capturing parenthesis is allowed in the regex pattern.""" - with pytest.raises(ValueError): - RegexOutputParser(name="Test", regex_config=config) - - -@pytest.mark.parametrize( - "expected_metrics", - [ - EXPECTED_METRICS_ALL, - EXPECTED_METRICS_PARTIAL, - ], -) -def test_base64_output_parser(expected_metrics: Dict) -> None: - """ - Make sure the Base64OutputParser yields valid results. - - I.e. return an empty dict if either the input or the config is empty and - return the parsed metrics otherwise. - """ - parser = Base64OutputParser(name="Test") - assert parser.name == "Test" - assert isinstance(parser, OutputParser) - - def create_base64_output(expected_metrics: Dict) -> bytearray: - json_str = json.dumps(expected_metrics, indent=4) - json_b64 = base64.b64encode(json_str.encode("utf-8")) - return ( - OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser - + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8") - + bytearray(json_b64) - + f"</{Base64OutputParser.TAG_NAME}>".encode("utf-8") - + OUTPUT_NO_MATCH # Just to add some difficulty... - ) - - output = create_base64_output(expected_metrics) - res = parser(output) - assert len(res) == 1 - assert isinstance(res, dict) - for val in res.values(): - assert val == expected_metrics - - output = parser.filter_out_parsed_content(output) - assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH) diff --git a/tests/mlia/test_backend_protocol.py b/tests/mlia/test_backend_protocol.py deleted file mode 100644 index 35e9986..0000000 --- a/tests/mlia/test_backend_protocol.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access -"""Tests for the protocol backend module.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from unittest.mock import MagicMock - -import paramiko -import pytest - -from mlia.backend.common import ConfigurationException -from mlia.backend.config import LocalProtocolConfig -from mlia.backend.protocol import CustomSFTPClient -from mlia.backend.protocol import LocalProtocol -from mlia.backend.protocol import ProtocolFactory -from mlia.backend.protocol import SSHProtocol - - -class TestProtocolFactory: - """Test ProtocolFactory class.""" - - @pytest.mark.parametrize( - "config, expected_class, exception", - [ - ( - { - "protocol": "ssh", - "username": "user", - "password": "pass", - "hostname": "hostname", - "port": "22", - }, - SSHProtocol, - does_not_raise(), - ), - ({"protocol": "local"}, LocalProtocol, does_not_raise()), - ( - {"protocol": "something"}, - None, - pytest.raises(Exception, match="Protocol not supported"), - ), - (None, None, pytest.raises(Exception, match="No protocol config provided")), - ], - ) - def test_get_protocol( - self, config: Any, expected_class: type, exception: Any - ) -> None: - """Test get_protocol method.""" - factory = ProtocolFactory() - with exception: - protocol = factory.get_protocol(config) - assert isinstance(protocol, expected_class) - - -class TestLocalProtocol: - """Test local protocol.""" - - def test_local_protocol_run_command(self) -> None: - """Test local protocol run command.""" - config = LocalProtocolConfig(protocol="local") - protocol = LocalProtocol(config, cwd=Path("/tmp")) - ret, stdout, stderr = protocol.run("pwd") - assert ret == 0 - assert stdout.decode("utf-8").strip() == "/tmp" - assert stderr.decode("utf-8") == "" - - def test_local_protocol_run_wrong_cwd(self) -> None: - """Execution should fail if wrong working directory provided.""" - config = LocalProtocolConfig(protocol="local") - protocol = LocalProtocol(config, cwd=Path("unknown_directory")) - with pytest.raises( - ConfigurationException, match="Wrong working directory unknown_directory" - ): - protocol.run("pwd") - - -class TestSSHProtocol: - """Test SSH protocol.""" - - @pytest.fixture(autouse=True) - def setup_method(self, monkeypatch: Any) -> None: - """Set up protocol mocks.""" - self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient) - - self.mock_ssh_channel = ( - self.mock_ssh_client.get_transport.return_value.open_session.return_value - ) - self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel) - self.mock_ssh_channel.exit_status_ready.side_effect = [False, True] - self.mock_ssh_channel.recv_exit_status.return_value = True - self.mock_ssh_channel.recv_ready.side_effect = [False, True] - self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True] - - monkeypatch.setattr( - "mlia.backend.protocol.paramiko.client.SSHClient", - MagicMock(return_value=self.mock_ssh_client), - ) - - self.mock_sftp_client = MagicMock(spec=CustomSFTPClient) - monkeypatch.setattr( - "mlia.backend.protocol.CustomSFTPClient.from_transport", - MagicMock(return_value=self.mock_sftp_client), - ) - - ssh_config = { - "protocol": "ssh", - "username": "user", - "password": "pass", - "hostname": "hostname", - "port": "22", - } - self.protocol = SSHProtocol(ssh_config) - - def test_unable_create_ssh_client(self, monkeypatch: Any) -> None: - """Test that command should fail if unable to create ssh client instance.""" - monkeypatch.setattr( - "mlia.backend.protocol.paramiko.client.SSHClient", - MagicMock(side_effect=OSError("Error!")), - ) - - with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_run_command(self) -> None: - """Test that command run via ssh successfully.""" - self.protocol.run("command_example") - self.mock_ssh_channel.exec_command.assert_called_once() - - def test_ssh_protocol_run_command_connect_failed(self) -> None: - """Test that if connection is not possible then correct exception is raised.""" - self.mock_ssh_client.connect.side_effect = OSError("Unable to connect") - self.mock_ssh_client.close.side_effect = Exception("Error!") - - with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_run_command_bad_transport(self) -> None: - """Test that command should fail if unable to get transport.""" - self.mock_ssh_client.get_transport.return_value = None - - with pytest.raises(Exception, match="Unable to get transport"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_deploy_command_file( - self, test_applications_path: Path - ) -> None: - """Test that files could be deployed over ssh.""" - file_for_deploy = test_applications_path / "readme.txt" - dest = "/tmp/dest" - - self.protocol.deploy(file_for_deploy, dest) - self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest) - - def test_ssh_protocol_deploy_command_unknown_file(self) -> None: - """Test that deploy will fail if file does not exist.""" - with pytest.raises(Exception, match="Deploy error: file type not supported"): - self.protocol.deploy(Path("unknown_file"), "/tmp/dest") - - def test_ssh_protocol_deploy_command_bad_transport(self) -> None: - """Test that deploy should fail if unable to get transport.""" - self.mock_ssh_client.get_transport.return_value = None - - with pytest.raises(Exception, match="Unable to get transport"): - self.protocol.deploy(Path("some_file"), "/tmp/dest") - - def test_ssh_protocol_deploy_command_directory( - self, test_resources_path: Path - ) -> None: - """Test that directory could be deployed over ssh.""" - directory_for_deploy = test_resources_path / "scripts" - dest = "/tmp/dest" - - self.protocol.deploy(directory_for_deploy, dest) - self.mock_sftp_client.put_dir.assert_called_once_with( - directory_for_deploy, dest - ) - - @pytest.mark.parametrize("establish_connection", (True, False)) - def test_ssh_protocol_close(self, establish_connection: bool) -> None: - """Test protocol close operation.""" - if establish_connection: - self.protocol.establish_connection() - self.protocol.close() - - call_count = 1 if establish_connection else 0 - assert self.mock_ssh_channel.exec_command.call_count == call_count - - def test_connection_details(self) -> None: - """Test getting connection details.""" - assert self.protocol.connection_details() == ("hostname", 22) - - -class TestCustomSFTPClient: - """Test CustomSFTPClient class.""" - - @pytest.fixture(autouse=True) - def setup_method(self, monkeypatch: Any) -> None: - """Set up mocks for CustomSFTPClient instance.""" - self.mock_mkdir = MagicMock() - self.mock_put = MagicMock() - monkeypatch.setattr( - "mlia.backend.protocol.paramiko.SFTPClient.__init__", - MagicMock(return_value=None), - ) - monkeypatch.setattr( - "mlia.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir - ) - monkeypatch.setattr( - "mlia.backend.protocol.paramiko.SFTPClient.put", self.mock_put - ) - - self.sftp_client = CustomSFTPClient(MagicMock()) - - def test_put_dir(self, test_systems_path: Path) -> None: - """Test deploying directory to remote host.""" - directory_for_deploy = test_systems_path / "system1" - - self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest") - assert self.mock_put.call_count == 3 - assert self.mock_mkdir.call_count == 3 - - def test_mkdir(self) -> None: - """Test creating directory on remote host.""" - self.mock_mkdir.side_effect = IOError("Cannot create directory") - - self.sftp_client._mkdir("new_directory", ignore_existing=True) - - with pytest.raises(IOError, match="Cannot create directory"): - self.sftp_client._mkdir("new_directory", ignore_existing=False) diff --git a/tests/mlia/test_utils_proc.py b/tests/mlia/test_utils_proc.py deleted file mode 100644 index 8316ca5..0000000 --- a/tests/mlia/test_utils_proc.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the module utils/proc.""" -import signal -import subprocess -import time -from pathlib import Path -from unittest.mock import MagicMock - -import pytest - -from mlia.utils.proc import CommandExecutor -from mlia.utils.proc import working_directory - - -class TestCommandExecutor: - """Tests for class CommandExecutor.""" - - @staticmethod - def test_execute() -> None: - """Test command execution.""" - executor = CommandExecutor() - - retcode, stdout, stderr = executor.execute(["echo", "hello world!"]) - assert retcode == 0 - assert stdout.decode().strip() == "hello world!" - assert stderr.decode() == "" - - @staticmethod - def test_submit() -> None: - """Test command submittion.""" - executor = CommandExecutor() - - running_command = executor.submit(["sleep", "10"]) - assert running_command.is_alive() is True - assert running_command.exit_code() is None - - running_command.kill() - for _ in range(3): - time.sleep(0.5) - if not running_command.is_alive(): - break - - assert running_command.is_alive() is False - assert running_command.exit_code() == -9 - - with pytest.raises(subprocess.CalledProcessError): - executor.execute(["sleep", "-1"]) - - @staticmethod - @pytest.mark.parametrize("wait", [True, False]) - def test_stop(wait: bool) -> None: - """Test command termination.""" - executor = CommandExecutor() - - running_command = executor.submit(["sleep", "10"]) - running_command.stop(wait=wait) - - if wait: - assert running_command.is_alive() is False - - @staticmethod - def test_unable_to_stop(monkeypatch: pytest.MonkeyPatch) -> None: - """Test when command could not be stopped.""" - running_command_mock = MagicMock() - running_command_mock.poll.return_value = None - - monkeypatch.setattr( - "mlia.utils.proc.subprocess.Popen", - MagicMock(return_value=running_command_mock), - ) - - with pytest.raises(Exception, match="Unable to stop running command"): - executor = CommandExecutor() - running_command = executor.submit(["sleep", "10"]) - - running_command.stop(num_of_attempts=1, interval=0.1) - - running_command_mock.send_signal.assert_called_once_with(signal.SIGINT) - - @staticmethod - def test_stop_after_several_attempts(monkeypatch: pytest.MonkeyPatch) -> None: - """Test when command could be stopped after several attempts.""" - running_command_mock = MagicMock() - running_command_mock.poll.side_effect = [None, 0] - - monkeypatch.setattr( - "mlia.utils.proc.subprocess.Popen", - MagicMock(return_value=running_command_mock), - ) - - executor = CommandExecutor() - running_command = executor.submit(["sleep", "10"]) - - running_command.stop(num_of_attempts=1, interval=0.1) - running_command_mock.send_signal.assert_called_once_with(signal.SIGINT) - - @staticmethod - def test_send_signal() -> None: - """Test sending signal.""" - executor = CommandExecutor() - running_command = executor.submit(["sleep", "10"]) - running_command.send_signal(signal.SIGINT) - - # wait a bit for a signal processing - time.sleep(1) - - assert running_command.is_alive() is False - assert running_command.exit_code() == -2 - - @staticmethod - @pytest.mark.parametrize( - "redirect_output, expected_output", [[True, "hello\n"], [False, ""]] - ) - def test_wait( - capsys: pytest.CaptureFixture, redirect_output: bool, expected_output: str - ) -> None: - """Test wait completion functionality.""" - executor = CommandExecutor() - - running_command = executor.submit(["echo", "hello"]) - running_command.wait(redirect_output=redirect_output) - - out, _ = capsys.readouterr() - assert out == expected_output - - -@pytest.mark.parametrize( - "should_exist, create_dir", - [ - [True, False], - [False, True], - ], -) -def test_working_directory_context_manager( - tmp_path: Path, should_exist: bool, create_dir: bool -) -> None: - """Test working_directory context manager.""" - prev_wd = Path.cwd() - - working_dir = tmp_path / "work_dir" - if should_exist: - working_dir.mkdir() - - with working_directory(working_dir, create_dir=create_dir) as current_working_dir: - assert current_working_dir.is_dir() - assert Path.cwd() == current_working_dir - - assert Path.cwd() == prev_wd diff --git a/tests/mlia/test_api.py b/tests/test_api.py index 09bc509..09bc509 100644 --- a/tests/mlia/test_api.py +++ b/tests/test_api.py diff --git a/tests/mlia/test_backend_application.py b/tests/test_backend_application.py index 2cfb2ef..6860ecb 100644 --- a/tests/mlia/test_backend_application.py +++ b/tests/test_backend_application.py @@ -20,7 +20,6 @@ from mlia.backend.application import install_application from mlia.backend.application import load_applications from mlia.backend.application import remove_application from mlia.backend.common import Command -from mlia.backend.common import DataPaths from mlia.backend.common import Param from mlia.backend.common import UserParamConfig from mlia.backend.config import ApplicationConfig @@ -186,7 +185,6 @@ class TestApplication: config = ApplicationConfig( # Application supported_systems=["system1", "system2"], - build_dir="build_dir", # inherited from Backend name="name", description="description", @@ -222,24 +220,6 @@ class TestApplication: assert len(applications) == 1 assert applications[0].can_run_on("System 1") - def test_get_deploy_data(self, tmp_path: Path) -> None: - """Test Application can run on.""" - src, dest = "src", "dest" - config = ApplicationConfig( - name="application", deploy_data=[(src, dest)], config_location=tmp_path - ) - src_path = tmp_path / src - src_path.mkdir() - application = Application(config) - assert application.get_deploy_data() == [DataPaths(src_path, dest)] - - def test_get_deploy_data_no_config_location(self) -> None: - """Test that getting deploy data fails if no config location provided.""" - with pytest.raises( - Exception, match="Unable to get application .* config location" - ): - Application(ApplicationConfig(name="application")).get_deploy_data() - def test_unable_to_create_application_without_name(self) -> None: """Test that it is not possible to create application without name.""" with pytest.raises(Exception, match="Name is empty"): @@ -394,14 +374,11 @@ def test_load_application() -> None: default_variables = {"var1": "value1", "var2": "value2"} application_5_0 = application_5[0] - assert application_5_0.build_dir == "default_build_dir" assert application_5_0.supported_systems == ["System 1"] assert application_5_0.commands == default_commands assert application_5_0.variables == default_variables - assert application_5_0.lock is False application_5_1 = application_5[1] - assert application_5_1.build_dir == application_5_0.build_dir assert application_5_1.supported_systems == ["System 2"] assert application_5_1.commands == application_5_1.commands assert application_5_1.variables == default_variables @@ -411,50 +388,31 @@ def test_load_application() -> None: application_5a_0 = application_5a[0] assert application_5a_0.supported_systems == ["System 1"] - assert application_5a_0.build_dir == "build_5A" assert application_5a_0.commands == default_commands assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"} - assert application_5a_0.lock is False application_5a_1 = application_5a[1] assert application_5a_1.supported_systems == ["System 2"] - assert application_5a_1.build_dir == "build" assert application_5a_1.commands == { - "build": Command(["default build command"]), + "build": default_commands["build"], "run": Command(["run command on system 2"]), } assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"} - assert application_5a_1.lock is True application_5b = get_application("application_5B") assert len(application_5b) == 2 application_5b_0 = application_5b[0] - assert application_5b_0.build_dir == "build_5B" assert application_5b_0.supported_systems == ["System 1"] assert application_5b_0.commands == { - "build": Command(["default build command with value for var1 System1"], []), + "build": Command(["default build command with value for var1 System1"]), "run": Command(["default run command with value for var2 System1"]), } assert "non_used_command" not in application_5b_0.commands application_5b_1 = application_5b[1] - assert application_5b_1.build_dir == "build" assert application_5b_1.supported_systems == ["System 2"] assert application_5b_1.commands == { - "build": Command( - [ - "build command on system 2 with value" - " for var1 System2 {user_params:param1}" - ], - [ - Param( - "--param", - "Sample command param", - ["value1", "value2", "value3"], - "value1", - ) - ], - ), + "build": Command(["default build command with value for var1 System2"]), "run": Command(["run command on system 2"], []), } diff --git a/tests/mlia/test_backend_common.py b/tests/test_backend_common.py index 82a985a..0533ef6 100644 --- a/tests/mlia/test_backend_common.py +++ b/tests/test_backend_common.py @@ -149,7 +149,7 @@ class TestBackend: application.validate_parameter("foo", "bar") assert "Unknown command: 'foo'" in str(err.value) - def test_build_command(self, monkeypatch: Any) -> None: + def test_build_command(self) -> None: """Test command building.""" config = { "name": "test", @@ -175,14 +175,12 @@ class TestBackend: "variables": {"var_A": "value for variable A"}, } - monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) application, system = Application(config), System(config) # type: ignore context = ExecutionContext( app=application, app_params=[], system=system, system_params=[], - custom_deploy_data=[], ) param_resolver = ParamResolver(context) @@ -284,13 +282,11 @@ class TestBackend: ) def test_resolved_parameters( self, - monkeypatch: Any, class_: type, config: Dict, expected_output: List[Tuple[Optional[str], Param]], ) -> None: """Test command building.""" - monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) backend = class_(config) params = backend.resolved_parameters( @@ -438,7 +434,7 @@ class TestCommand: Param("param", "param description", [], None, "alias"), Param("param", "param description", [], None, "alias"), ], - pytest.raises(ConfigurationException, match="Non unique aliases alias"), + pytest.raises(ConfigurationException, match="Non-unique aliases alias"), ], [ [ @@ -475,7 +471,7 @@ class TestCommand: Param("param4", "param4 description", [], None, "alias2"), ], pytest.raises( - ConfigurationException, match="Non unique aliases alias1, alias2" + ConfigurationException, match="Non-unique aliases alias1, alias2" ), ], ], diff --git a/tests/test_backend_execution.py b/tests/test_backend_execution.py new file mode 100644 index 0000000..07e7c98 --- /dev/null +++ b/tests/test_backend_execution.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Test backend execution module.""" +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.application import Application +from mlia.backend.common import UserParamConfig +from mlia.backend.config import ApplicationConfig +from mlia.backend.config import SystemConfig +from mlia.backend.execution import ExecutionContext +from mlia.backend.execution import get_application_and_system +from mlia.backend.execution import get_application_by_name_and_system +from mlia.backend.execution import ParamResolver +from mlia.backend.system import load_system + + +def test_context_param_resolver(tmpdir: Any) -> None: + """Test parameter resolving.""" + system_config_location = Path(tmpdir) / "system" + system_config_location.mkdir() + + application_config_location = Path(tmpdir) / "application" + application_config_location.mkdir() + + ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + description="Test application", + config_location=application_config_location, + commands={ + "run": [ + "run_command1 {user_params:0}", + "run_command2 {user_params:1}", + ] + }, + variables={"var_1": "value for var_1"}, + user_params={ + "run": [ + UserParamConfig( + name="--param1", + description="Param 1", + default_value="123", + alias="param_1", + ), + UserParamConfig( + name="--param2", description="Param 2", default_value="456" + ), + UserParamConfig( + name="--param3", description="Param 3", alias="param_3" + ), + UserParamConfig( + name="--param4=", + description="Param 4", + default_value="456", + alias="param_4", + ), + UserParamConfig( + description="Param 5", + default_value="789", + alias="param_5", + ), + ] + }, + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + commands={ + "build": ["build_command1 {user_params:0}"], + "run": ["run_command {application.commands.run:1}"], + }, + variables={"var_1": "value for var_1"}, + user_params={ + "build": [ + UserParamConfig( + name="--param1", description="Param 1", default_value="aaa" + ), + UserParamConfig(name="--param2", description="Param 2"), + ] + }, + ) + ), + system_params=["--param1=bbb"], + ) + + param_resolver = ParamResolver(ctx) + expected_values = { + "application.name": "test_application", + "application.description": "Test application", + "application.config_dir": str(application_config_location), + "application.commands.run:0": "run_command1 --param1 123", + "application.commands.run.params:0": "123", + "application.commands.run.params:param_1": "123", + "application.commands.run:1": "run_command2 --param2 789", + "application.commands.run.params:1": "789", + "application.variables:var_1": "value for var_1", + "system.name": "test_system", + "system.description": "Test system", + "system.config_dir": str(system_config_location), + "system.commands.build:0": "build_command1 --param1 bbb", + "system.commands.run:0": "run_command run_command2 --param2 789", + "system.commands.build.params:0": "bbb", + "system.variables:var_1": "value for var_1", + } + + for param, value in expected_values.items(): + assert param_resolver(param) == value + + expected_errors = { + "application.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + "application.commands.clean:0": pytest.raises( + Exception, match="Command clean not found" + ), + "application.commands.run:2": pytest.raises( + Exception, match="Invalid index 2 for command run" + ), + "application.commands.run.params:5": pytest.raises( + Exception, match="Invalid parameter index 5 for command run" + ), + "application.commands.run.params:param_2": pytest.raises( + Exception, + match="No value for parameter with index or alias param_2 of command run", + ), + "UNKNOWN": pytest.raises( + Exception, match="Unable to resolve parameter UNKNOWN" + ), + "system.commands.build.params:1": pytest.raises( + Exception, + match="No value for parameter with index or alias 1 of command build", + ), + "system.commands.build:A": pytest.raises( + Exception, match="Bad command index A" + ), + "system.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + } + for param, error in expected_errors.items(): + with error: + param_resolver(param) + + resolved_params = ctx.app.resolved_parameters("run", []) + expected_user_params = { + "user_params:0": "--param1 123", + "user_params:param_1": "--param1 123", + "user_params:2": "--param3", + "user_params:param_3": "--param3", + "user_params:3": "--param4=456", + "user_params:param_4": "--param4=456", + "user_params:param_5": "789", + } + for param, expected_value in expected_user_params.items(): + assert param_resolver(param, "run", resolved_params) == expected_value + + with pytest.raises( + Exception, match="Invalid index 5 for user params of command run" + ): + param_resolver("user_params:5", "run", resolved_params) + + with pytest.raises( + Exception, match="No user parameter for command 'run' with alias 'param_2'." + ): + param_resolver("user_params:param_2", "run", resolved_params) + + with pytest.raises(Exception, match="Unable to resolve user params"): + param_resolver("user_params:0", "", resolved_params) + + +def test_get_application_by_name_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_by_name_and_system.""" + monkeypatch.setattr( + "mlia.backend.execution.get_application", + MagicMock(return_value=[MagicMock(), MagicMock()]), + ) + + with pytest.raises( + ValueError, + match="Error during getting application test_application for the " + "system test_system", + ): + get_application_by_name_and_system("test_application", "test_system") + + +def test_get_application_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_and_system.""" + monkeypatch.setattr( + "mlia.backend.execution.get_system", MagicMock(return_value=None) + ) + + with pytest.raises(ValueError, match="System test_system is not found"): + get_application_and_system("test_application", "test_system") diff --git a/tests/mlia/test_backend_fs.py b/tests/test_backend_fs.py index ff9c2ae..7423222 100644 --- a/tests/mlia/test_backend_fs.py +++ b/tests/test_backend_fs.py @@ -11,8 +11,6 @@ from unittest.mock import MagicMock import pytest from mlia.backend.fs import get_backends_path -from mlia.backend.fs import read_file_as_bytearray -from mlia.backend.fs import read_file_as_string from mlia.backend.fs import recreate_directory from mlia.backend.fs import remove_directory from mlia.backend.fs import remove_resource @@ -120,38 +118,6 @@ def write_to_file( return tmpfile -class TestReadFileAsString: - """Test read_file_as_string() function.""" - - def test_returns_text_from_valid_file(self, tmpdir: Any) -> None: - """Ensure the string written to a file read correctly.""" - file_path = write_to_file(tmpdir, "w", "hello") - assert read_file_as_string(file_path) == "hello" - - def test_output_is_empty_string_when_input_file_non_existent( - self, tmpdir: Any - ) -> None: - """Ensure empty string returned when reading from non-existent file.""" - file_path = Path(tmpdir / "non-existent.txt") - assert read_file_as_string(file_path) == "" - - -class TestReadFileAsByteArray: - """Test read_file_as_bytearray() function.""" - - def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None: - """Ensure the bytes written to a file read correctly.""" - file_path = write_to_file(tmpdir, "wb", b"hello bytes") - assert read_file_as_bytearray(file_path) == b"hello bytes" - - def test_output_is_empty_bytearray_when_input_file_non_existent( - self, tmpdir: Any - ) -> None: - """Ensure empty bytearray returned when reading from non-existent file.""" - file_path = Path(tmpdir / "non-existent.txt") - assert read_file_as_bytearray(file_path) == bytearray() - - @pytest.mark.parametrize( "value, replacement, expected_result", [ diff --git a/tests/mlia/test_backend_manager.py b/tests/test_backend_manager.py index c81366f..1b5fea1 100644 --- a/tests/mlia/test_backend_manager.py +++ b/tests/test_backend_manager.py @@ -1,13 +1,15 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module backend/manager.""" -import os +import base64 +import json from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Tuple from unittest.mock import MagicMock from unittest.mock import PropertyMock @@ -15,7 +17,6 @@ from unittest.mock import PropertyMock import pytest from mlia.backend.application import get_application -from mlia.backend.common import DataPaths from mlia.backend.execution import ExecutionContext from mlia.backend.execution import run_application from mlia.backend.manager import BackendRunner @@ -30,9 +31,45 @@ from mlia.backend.manager import is_supported from mlia.backend.manager import ModelInfo from mlia.backend.manager import PerformanceMetrics from mlia.backend.manager import supported_backends +from mlia.backend.output_consumer import Base64OutputConsumer from mlia.backend.system import get_system +def _mock_encode_b64(data: Dict[str, int]) -> str: + """ + Encode the given data into a mock base64-encoded string of JSON. + + This reproduces the base64 encoding done in the Corstone applications. + + JSON example: + + ```json + [{'count': 1, + 'profiling_group': 'Inference', + 'samples': [{'name': 'NPU IDLE', 'value': [612]}, + {'name': 'NPU AXI0_RD_DATA_BEAT_RECEIVED', 'value': [165872]}, + {'name': 'NPU AXI0_WR_DATA_BEAT_WRITTEN', 'value': [88712]}, + {'name': 'NPU AXI1_RD_DATA_BEAT_RECEIVED', 'value': [57540]}, + {'name': 'NPU ACTIVE', 'value': [520489]}, + {'name': 'NPU TOTAL', 'value': [521101]}]}] + ``` + """ + wrapped_data = [ + { + "count": 1, + "profiling_group": "Inference", + "samples": [ + {"name": name, "value": [value]} for name, value in data.items() + ], + } + ] + json_str = json.dumps(wrapped_data) + json_bytes = bytearray(json_str, encoding="utf-8") + json_b64 = base64.b64encode(json_bytes).decode("utf-8") + tag = Base64OutputConsumer.TAG_NAME + return f"<{tag}>{json_b64}</{tag}>" + + @pytest.mark.parametrize( "data, is_ready, result, missed_keys", [ @@ -40,50 +77,52 @@ from mlia.backend.system import get_system [], False, {}, - [ + { "npu_active_cycles", "npu_axi0_rd_data_beat_received", "npu_axi0_wr_data_beat_written", "npu_axi1_rd_data_beat_received", "npu_idle_cycles", "npu_total_cycles", - ], + }, ), ( ["sample text"], False, {}, - [ + { "npu_active_cycles", "npu_axi0_rd_data_beat_received", "npu_axi0_wr_data_beat_written", "npu_axi1_rd_data_beat_received", "npu_idle_cycles", "npu_total_cycles", - ], + }, ), ( - [ - ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"], - False, - {"npu_axi0_rd_data_beat_received": 123}, - [ - "npu_active_cycles", - "npu_axi0_wr_data_beat_written", - "npu_axi1_rd_data_beat_received", - "npu_idle_cycles", - "npu_total_cycles", - ], - ] + [_mock_encode_b64({"NPU AXI0_RD_DATA_BEAT_RECEIVED": 123})], + False, + {"npu_axi0_rd_data_beat_received": 123}, + { + "npu_active_cycles", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + }, ), ( [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - "NPU TOTAL cycles: 6", + _mock_encode_b64( + { + "NPU AXI0_RD_DATA_BEAT_RECEIVED": 1, + "NPU AXI0_WR_DATA_BEAT_WRITTEN": 2, + "NPU AXI1_RD_DATA_BEAT_RECEIVED": 3, + "NPU ACTIVE": 4, + "NPU IDLE": 5, + "NPU TOTAL": 6, + } + ) ], True, { @@ -94,12 +133,12 @@ from mlia.backend.system import get_system "npu_idle_cycles": 5, "npu_total_cycles": 6, }, - [], + set(), ), ], ) def test_generic_inference_output_parser( - data: List[str], is_ready: bool, result: Dict, missed_keys: List[str] + data: Dict[str, int], is_ready: bool, result: Dict, missed_keys: Set[str] ) -> None: """Test generic runner output parser.""" parser = GenericInferenceOutputParser() @@ -316,8 +355,8 @@ class TestBackendRunner: "execution_params, expected_command", [ ( - ExecutionParams("application_4", "System 4", [], [], []), - ["application_4", [], "System 4", [], []], + ExecutionParams("application_4", "System 4", [], []), + ["application_4", [], "System 4", []], ), ( ExecutionParams( @@ -325,14 +364,12 @@ class TestBackendRunner: "System 6", ["param1=value2"], ["sys-param1=value2"], - [], ), [ "application_6", ["param1=value2"], "System 6", ["sys-param1=value2"], - [], ], ), ], @@ -351,67 +388,6 @@ class TestBackendRunner: run_app.assert_called_once_with(*expected_command) - @staticmethod - @pytest.mark.parametrize( - "execution_params, expected_command", - [ - ( - ExecutionParams( - "application_1", - "System 1", - [], - [], - ["source1.txt:dest1.txt", "source2.txt:dest2.txt"], - ), - [ - "application_1", - [], - "System 1", - [], - [ - DataPaths(Path("source1.txt"), "dest1.txt"), - DataPaths(Path("source2.txt"), "dest2.txt"), - ], - ], - ), - ], - ) - def test_run_application_connected( - monkeypatch: pytest.MonkeyPatch, - execution_params: ExecutionParams, - expected_command: List[str], - ) -> None: - """Test method run_application with connectable systems (SSH).""" - app = get_application(execution_params.application, execution_params.system)[0] - sys = get_system(execution_params.system) - - assert sys is not None - - connect_mock = MagicMock(return_value=True, name="connect_mock") - deploy_mock = MagicMock(return_value=True, name="deploy_mock") - run_mock = MagicMock( - return_value=(os.EX_OK, bytearray(), bytearray()), name="run_mock" - ) - sys.establish_connection = connect_mock # type: ignore - sys.deploy = deploy_mock # type: ignore - sys.run = run_mock # type: ignore - - monkeypatch.setattr( - "mlia.backend.execution.get_application_and_system", - MagicMock(return_value=(app, sys)), - ) - - run_app_mock = MagicMock(wraps=run_application) - monkeypatch.setattr("mlia.backend.manager.run_application", run_app_mock) - - backend_runner = BackendRunner() - backend_runner.run_application(execution_params) - - run_app_mock.assert_called_once_with(*expected_command) - - connect_mock.assert_called_once() - assert deploy_mock.call_count == 2 - @pytest.mark.parametrize( "device, system, application, backend, expected_error", @@ -531,12 +507,16 @@ def test_estimate_performance( mock_context = create_mock_context( [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - "NPU TOTAL cycles: 6", + _mock_encode_b64( + { + "NPU AXI0_RD_DATA_BEAT_RECEIVED": 1, + "NPU AXI0_WR_DATA_BEAT_WRITTEN": 2, + "NPU AXI1_RD_DATA_BEAT_RECEIVED": 3, + "NPU ACTIVE": 4, + "NPU IDLE": 5, + "NPU TOTAL": 6, + } + ) ] ) @@ -571,14 +551,14 @@ def test_estimate_performance_insufficient_data( backend_runner.is_system_installed.return_value = True backend_runner.is_application_installed.return_value = True - no_total_cycles_output = [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - ] - mock_context = create_mock_context(no_total_cycles_output) + no_total_cycles_output = { + "NPU AXI0_RD_DATA_BEAT_RECEIVED": 1, + "NPU AXI0_WR_DATA_BEAT_WRITTEN": 2, + "NPU AXI1_RD_DATA_BEAT_RECEIVED": 3, + "NPU ACTIVE": 4, + "NPU IDLE": 5, + } + mock_context = create_mock_context([_mock_encode_b64(no_total_cycles_output)]) backend_runner.run_application.return_value = mock_context diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py new file mode 100644 index 0000000..881112e --- /dev/null +++ b/tests/test_backend_output_consumer.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the output parsing.""" +import base64 +import json +from typing import Any +from typing import Dict + +import pytest + +from mlia.backend.output_consumer import Base64OutputConsumer +from mlia.backend.output_consumer import OutputConsumer + + +OUTPUT_MATCH_ALL = bytearray( + """ +String1: My awesome string! +String2: STRINGS_ARE_GREAT!!! +Int: 12 +Float: 3.14 +""", + encoding="utf-8", +) + +OUTPUT_NO_MATCH = bytearray( + """ +This contains no matches... +Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| +""", + encoding="utf-8", +) + +OUTPUT_PARTIAL_MATCH = bytearray( + "String1: My awesome string!", + encoding="utf-8", +) + +REGEX_CONFIG = { + "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, + "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, + "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, + "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, +} + +EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} + +EXPECTED_METRICS_ALL = { + "FirstString": "My awesome string!", + "SecondString": "STRINGS_ARE_GREAT", + "IntegerValue": 12, + "FloatValue": 3.14, +} + +EXPECTED_METRICS_PARTIAL = { + "FirstString": "My awesome string!", +} + + +@pytest.mark.parametrize( + "expected_metrics", + [ + EXPECTED_METRICS_ALL, + EXPECTED_METRICS_PARTIAL, + ], +) +def test_base64_output_consumer(expected_metrics: Dict) -> None: + """ + Make sure the Base64OutputConsumer yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = Base64OutputConsumer() + assert isinstance(parser, OutputConsumer) + + def create_base64_output(expected_metrics: Dict) -> bytearray: + json_str = json.dumps(expected_metrics, indent=4) + json_b64 = base64.b64encode(json_str.encode("utf-8")) + return ( + OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputConsumer + + f"<{Base64OutputConsumer.TAG_NAME}>".encode("utf-8") + + bytearray(json_b64) + + f"</{Base64OutputConsumer.TAG_NAME}>".encode("utf-8") + + OUTPUT_NO_MATCH # Just to add some difficulty... + ) + + output = create_base64_output(expected_metrics) + + consumed = False + for line in output.splitlines(): + if parser.feed(line.decode("utf-8")): + consumed = True + assert consumed # we should have consumed at least one line + + res = parser.parsed_output + assert len(res) == 1 + assert isinstance(res, list) + for val in res: + assert val == expected_metrics diff --git a/tests/mlia/test_backend_proc.py b/tests/test_backend_proc.py index 9ca4788..f47c244 100644 --- a/tests/mlia/test_backend_proc.py +++ b/tests/test_backend_proc.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any from unittest import mock -import psutil import pytest from sh import ErrorReturnCode @@ -16,10 +15,8 @@ from mlia.backend.proc import CommandNotFound from mlia.backend.proc import parse_command from mlia.backend.proc import print_command_stdout from mlia.backend.proc import run_and_wait -from mlia.backend.proc import save_process_info from mlia.backend.proc import ShellCommand from mlia.backend.proc import terminate_command -from mlia.backend.proc import terminate_external_process class TestShellCommand: @@ -109,61 +106,6 @@ def test_print_command_stdout_not_alive(mock_print: Any) -> None: mock_print.assert_called_once_with("test") -def test_terminate_external_process_no_process(capsys: Any) -> None: - """Test that non existed process could be terminated.""" - mock_command = mock.MagicMock() - mock_command.terminate.side_effect = psutil.Error("Error!") - - terminate_external_process(mock_command) - captured = capsys.readouterr() - assert captured.out == "Unable to terminate process\n" - - -def test_terminate_external_process_case1() -> None: - """Test when process terminated immediately.""" - mock_command = mock.MagicMock() - mock_command.is_running.return_value = False - - terminate_external_process(mock_command) - mock_command.terminate.assert_called_once() - mock_command.is_running.assert_called_once() - - -def test_terminate_external_process_case2() -> None: - """Test when process termination takes time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, False] - - terminate_external_process(mock_command) - mock_command.terminate.assert_called_once() - assert mock_command.is_running.call_count == 3 - - -def test_terminate_external_process_case3() -> None: - """Test when process termination takes more time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, True] - - terminate_external_process( - mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 - ) - assert mock_command.is_running.call_count == 3 - assert mock_command.terminate.call_count == 2 - - -def test_terminate_external_process_case4() -> None: - """Test when process termination takes more time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, False] - - terminate_external_process( - mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 - ) - mock_command.terminate.assert_called_once() - assert mock_command.is_running.call_count == 3 - assert mock_command.terminate.call_count == 1 - - def test_terminate_command_no_process() -> None: """Test command termination when process does not exist.""" mock_command = mock.MagicMock() @@ -253,17 +195,6 @@ class TestRunAndWait: assert self.terminate_command_mock.call_count == call_count -def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None: - """Test save_process_info function.""" - mock_process = mock.MagicMock() - monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process)) - mock_process.children.side_effect = psutil.NoSuchProcess(555) - - pid_file_path = Path(tmpdir) / "test.pid" - save_process_info(555, pid_file_path) - assert not pid_file_path.exists() - - def test_parse_command() -> None: """Test parse_command function.""" assert parse_command("1.sh") == ["bash", "1.sh"] diff --git a/tests/mlia/test_backend_source.py b/tests/test_backend_source.py index 84a6a77..11f1781 100644 --- a/tests/mlia/test_backend_source.py +++ b/tests/test_backend_source.py @@ -160,8 +160,8 @@ class TestTarArchiveSource: tmpdir_path = Path(tmpdir) tar_source.install_into(tmpdir_path) source_files = [ - "aiet-config.json.license", - "aiet-config.json", + "backend-config.json.license", + "backend-config.json", "system_artifact", ] dest_files = [f.name for f in tmpdir_path.iterdir()] diff --git a/tests/mlia/test_backend_system.py b/tests/test_backend_system.py index 21187ff..13347c6 100644 --- a/tests/mlia/test_backend_system.py +++ b/tests/test_backend_system.py @@ -17,24 +17,12 @@ from mlia.backend.common import Command from mlia.backend.common import ConfigurationException from mlia.backend.common import Param from mlia.backend.common import UserParamConfig -from mlia.backend.config import LocalProtocolConfig -from mlia.backend.config import ProtocolConfig -from mlia.backend.config import SSHConfig from mlia.backend.config import SystemConfig -from mlia.backend.controller import SystemController -from mlia.backend.controller import SystemControllerSingleInstance -from mlia.backend.protocol import LocalProtocol -from mlia.backend.protocol import SSHProtocol -from mlia.backend.protocol import SupportsClose -from mlia.backend.protocol import SupportsDeploy -from mlia.backend.system import ControlledSystem from mlia.backend.system import get_available_systems -from mlia.backend.system import get_controller from mlia.backend.system import get_system from mlia.backend.system import install_system from mlia.backend.system import load_system from mlia.backend.system import remove_system -from mlia.backend.system import StandaloneSystem from mlia.backend.system import System @@ -68,9 +56,7 @@ def test_get_available_systems() -> None: def test_get_system() -> None: """Test get_system.""" system1 = get_system("System 1") - assert isinstance(system1, ControlledSystem) - assert system1.connectable is True - assert system1.connection_details() == ("localhost", 8021) + assert isinstance(system1, System) assert system1.name == "System 1" system2 = get_system("System 2") @@ -78,8 +64,10 @@ def test_get_system() -> None: assert system1 != 42 assert system1 != system2 - system = get_system("Unknown system") - assert system is None + with pytest.raises( + ConfigurationException, match="System 'Unknown system' not found." + ): + get_system("Unknown system") @pytest.mark.parametrize( @@ -142,10 +130,9 @@ def test_remove_system(monkeypatch: Any) -> None: mock_remove_backend.assert_called_once() -def test_system(monkeypatch: Any) -> None: +def test_system() -> None: """Test the System class.""" config = SystemConfig(name="System 1") - monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) system = System(config) assert str(system) == "System 1" assert system.name == "System 1" @@ -162,134 +149,34 @@ def test_system_with_empty_parameter_name() -> None: System(bad_config) -def test_system_standalone_run() -> None: - """Test run operation for standalone system.""" +def test_system_run() -> None: + """Test run operation for system.""" system = get_system("System 4") - assert isinstance(system, StandaloneSystem) - - with pytest.raises( - ConfigurationException, match="System .* does not support connections" - ): - system.connection_details() - - with pytest.raises( - ConfigurationException, match="System .* does not support connections" - ): - system.establish_connection() - - assert system.connectable is False + assert isinstance(system, System) system.run("echo 'application run'") -@pytest.mark.parametrize( - "system_name, expected_value", [("System 1", True), ("System 4", False)] -) -def test_system_supports_deploy(system_name: str, expected_value: bool) -> None: - """Test system property supports_deploy.""" - system = get_system(system_name) - if system is None: - pytest.fail("Unable to get system {}".format(system_name)) - assert system.supports_deploy == expected_value - - -@pytest.mark.parametrize( - "mock_protocol", - [ - MagicMock(spec=SSHProtocol), - MagicMock( - spec=SSHProtocol, - **{"close.side_effect": ValueError("Unable to close protocol")} - ), - MagicMock(spec=LocalProtocol), - ], -) -def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None: - """Test system start, run commands and stop.""" - monkeypatch.setattr( - "mlia.backend.system.ProtocolFactory.get_protocol", - MagicMock(return_value=mock_protocol), - ) - - system = get_system("System 1") - if system is None: - pytest.fail("Unable to get system") - assert isinstance(system, ControlledSystem) - - with pytest.raises(Exception, match="System has not been started"): - system.stop() - - assert not system.is_running() - assert system.get_output() == ("", "") - system.start(["sleep 10"], False) - assert system.is_running() - system.stop(wait=True) - assert not system.is_running() - assert system.get_output() == ("", "") - - if isinstance(mock_protocol, SupportsClose): - mock_protocol.close.assert_called_once() - - if isinstance(mock_protocol, SSHProtocol): - system.establish_connection() - - def test_system_start_no_config_location() -> None: """Test that system without config location could not start.""" - system = load_system( - SystemConfig( - name="test", - data_transfer=SSHConfig( - protocol="ssh", - username="user", - password="user", - hostname="localhost", - port="123", - ), - ) - ) + system = load_system(SystemConfig(name="test")) - assert isinstance(system, ControlledSystem) + assert isinstance(system, System) with pytest.raises( - ConfigurationException, match="System test has wrong config location" + ConfigurationException, match="System has invalid config location: None" ): - system.start(["sleep 100"]) + system.run("sleep 100") @pytest.mark.parametrize( "config, expected_class, expected_error", [ ( - SystemConfig( - name="test", - data_transfer=SSHConfig( - protocol="ssh", - username="user", - password="user", - hostname="localhost", - port="123", - ), - ), - ControlledSystem, + SystemConfig(name="test"), + System, does_not_raise(), ), - ( - SystemConfig( - name="test", data_transfer=LocalProtocolConfig(protocol="local") - ), - StandaloneSystem, - does_not_raise(), - ), - ( - SystemConfig( - name="test", - data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore - ), - None, - pytest.raises( - Exception, match="Unsupported execution type for protocol cool_protocol" - ), - ), + (SystemConfig(), None, pytest.raises(ConfigurationException)), ], ) def test_load_system( @@ -310,7 +197,6 @@ def test_load_system_populate_shared_params() -> None: load_system( SystemConfig( name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), user_params={ "shared": [ UserParamConfig( @@ -330,7 +216,6 @@ def test_load_system_populate_shared_params() -> None: load_system( SystemConfig( name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), user_params={ "shared": [ UserParamConfig( @@ -355,7 +240,6 @@ def test_load_system_populate_shared_params() -> None: system0 = load_system( SystemConfig( name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), commands={"run": ["run_command"]}, user_params={ "shared": [], @@ -389,7 +273,6 @@ def test_load_system_populate_shared_params() -> None: system1 = load_system( SystemConfig( name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), user_params={ "shared": [ UserParamConfig( @@ -412,20 +295,7 @@ def test_load_system_populate_shared_params() -> None: }, ) ) - assert len(system1.commands) == 2 - build_command1 = system1.commands["build"] - assert build_command1 == Command( - [], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ) - ], - ) + assert len(system1.commands) == 1 run_command1 = system1.commands["run"] assert run_command1 == Command( @@ -451,7 +321,6 @@ def test_load_system_populate_shared_params() -> None: system2 = load_system( SystemConfig( name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), commands={"build": ["build_command"]}, user_params={ "shared": [ @@ -479,15 +348,7 @@ def test_load_system_populate_shared_params() -> None: build_command2 = system2.commands["build"] assert build_command2 == Command( ["build_command"], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ) - ], + [], ) run_command2 = system1.commands["run"] @@ -510,32 +371,3 @@ def test_load_system_populate_shared_params() -> None: ), ], ) - - -@pytest.mark.parametrize( - "mock_protocol, expected_call_count", - [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)], -) -def test_system_deploy_data( - monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int -) -> None: - """Test deploy data functionality.""" - monkeypatch.setattr( - "mlia.backend.system.ProtocolFactory.get_protocol", - MagicMock(return_value=mock_protocol), - ) - - system = ControlledSystem(SystemConfig(name="test")) - system.deploy(Path("some_file"), "some_dest") - - assert mock_protocol.deploy.call_count == expected_call_count - - -@pytest.mark.parametrize( - "single_instance, controller_class", - ((False, SystemController), (True, SystemControllerSingleInstance)), -) -def test_get_controller(single_instance: bool, controller_class: type) -> None: - """Test function get_controller.""" - controller = get_controller(single_instance) - assert isinstance(controller, controller_class) diff --git a/tests/mlia/test_cli_commands.py b/tests/test_cli_commands.py index bf17339..bf17339 100644 --- a/tests/mlia/test_cli_commands.py +++ b/tests/test_cli_commands.py diff --git a/tests/mlia/test_cli_config.py b/tests/test_cli_config.py index 6d19eec..6d19eec 100644 --- a/tests/mlia/test_cli_config.py +++ b/tests/test_cli_config.py diff --git a/tests/mlia/test_cli_helpers.py b/tests/test_cli_helpers.py index 2c52885..2c52885 100644 --- a/tests/mlia/test_cli_helpers.py +++ b/tests/test_cli_helpers.py diff --git a/tests/mlia/test_cli_logging.py b/tests/test_cli_logging.py index 3f59cb6..5d26551 100644 --- a/tests/mlia/test_cli_logging.py +++ b/tests/test_cli_logging.py @@ -8,7 +8,7 @@ from typing import Optional import pytest from mlia.cli.logging import setup_logging -from tests.mlia.utils.logging import clear_loggers +from tests.utils.logging import clear_loggers def teardown_function() -> None: diff --git a/tests/mlia/test_cli_main.py b/tests/test_cli_main.py index a0937d5..28abc7b 100644 --- a/tests/mlia/test_cli_main.py +++ b/tests/test_cli_main.py @@ -17,7 +17,7 @@ import mlia from mlia.cli.main import CommandInfo from mlia.cli.main import main from mlia.core.context import ExecutionContext -from tests.mlia.utils.logging import clear_loggers +from tests.utils.logging import clear_loggers def teardown_function() -> None: diff --git a/tests/mlia/test_cli_options.py b/tests/test_cli_options.py index a441e58..a441e58 100644 --- a/tests/mlia/test_cli_options.py +++ b/tests/test_cli_options.py diff --git a/tests/mlia/test_core_advice_generation.py b/tests/test_core_advice_generation.py index 05db698..05db698 100644 --- a/tests/mlia/test_core_advice_generation.py +++ b/tests/test_core_advice_generation.py diff --git a/tests/mlia/test_core_advisor.py b/tests/test_core_advisor.py index 375ff62..375ff62 100644 --- a/tests/mlia/test_core_advisor.py +++ b/tests/test_core_advisor.py diff --git a/tests/mlia/test_core_context.py b/tests/test_core_context.py index 44eb976..44eb976 100644 --- a/tests/mlia/test_core_context.py +++ b/tests/test_core_context.py diff --git a/tests/mlia/test_core_data_analysis.py b/tests/test_core_data_analysis.py index a782159..a782159 100644 --- a/tests/mlia/test_core_data_analysis.py +++ b/tests/test_core_data_analysis.py diff --git a/tests/mlia/test_core_events.py b/tests/test_core_events.py index faaab7c..faaab7c 100644 --- a/tests/mlia/test_core_events.py +++ b/tests/test_core_events.py diff --git a/tests/mlia/test_core_helpers.py b/tests/test_core_helpers.py index 8577617..8577617 100644 --- a/tests/mlia/test_core_helpers.py +++ b/tests/test_core_helpers.py diff --git a/tests/mlia/test_core_mixins.py b/tests/test_core_mixins.py index d66213d..d66213d 100644 --- a/tests/mlia/test_core_mixins.py +++ b/tests/test_core_mixins.py diff --git a/tests/mlia/test_core_performance.py b/tests/test_core_performance.py index 0d28fe8..0d28fe8 100644 --- a/tests/mlia/test_core_performance.py +++ b/tests/test_core_performance.py diff --git a/tests/mlia/test_core_reporting.py b/tests/test_core_reporting.py index 2f7ec22..2f7ec22 100644 --- a/tests/mlia/test_core_reporting.py +++ b/tests/test_core_reporting.py diff --git a/tests/mlia/test_core_workflow.py b/tests/test_core_workflow.py index 470e572..470e572 100644 --- a/tests/mlia/test_core_workflow.py +++ b/tests/test_core_workflow.py diff --git a/tests/mlia/test_devices_ethosu_advice_generation.py b/tests/test_devices_ethosu_advice_generation.py index 5d37376..5d37376 100644 --- a/tests/mlia/test_devices_ethosu_advice_generation.py +++ b/tests/test_devices_ethosu_advice_generation.py diff --git a/tests/mlia/test_devices_ethosu_advisor.py b/tests/test_devices_ethosu_advisor.py index 74d2408..74d2408 100644 --- a/tests/mlia/test_devices_ethosu_advisor.py +++ b/tests/test_devices_ethosu_advisor.py diff --git a/tests/mlia/test_devices_ethosu_config.py b/tests/test_devices_ethosu_config.py index 49c999a..49c999a 100644 --- a/tests/mlia/test_devices_ethosu_config.py +++ b/tests/test_devices_ethosu_config.py diff --git a/tests/mlia/test_devices_ethosu_data_analysis.py b/tests/test_devices_ethosu_data_analysis.py index 4b1d38b..4b1d38b 100644 --- a/tests/mlia/test_devices_ethosu_data_analysis.py +++ b/tests/test_devices_ethosu_data_analysis.py diff --git a/tests/mlia/test_devices_ethosu_data_collection.py b/tests/test_devices_ethosu_data_collection.py index 897cf41..897cf41 100644 --- a/tests/mlia/test_devices_ethosu_data_collection.py +++ b/tests/test_devices_ethosu_data_collection.py diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/test_devices_ethosu_performance.py index b3e5298..b3e5298 100644 --- a/tests/mlia/test_devices_ethosu_performance.py +++ b/tests/test_devices_ethosu_performance.py diff --git a/tests/mlia/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py index 0da50e0..0da50e0 100644 --- a/tests/mlia/test_devices_ethosu_reporters.py +++ b/tests/test_devices_ethosu_reporters.py diff --git a/tests/mlia/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index 1ac9f97..1ac9f97 100644 --- a/tests/mlia/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py diff --git a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py index 9bcf918..c12a1e8 100644 --- a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py +++ b/tests/test_nn_tensorflow_optimizations_clustering.py @@ -16,8 +16,8 @@ from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_tflite_model -from tests.mlia.utils.common import get_dataset -from tests.mlia.utils.common import train_model +from tests.utils.common import get_dataset +from tests.utils.common import train_model def _prune_model( diff --git a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py index 64030a6..5d92f5e 100644 --- a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py +++ b/tests/test_nn_tensorflow_optimizations_pruning.py @@ -14,8 +14,8 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_tflite_model -from tests.mlia.utils.common import get_dataset -from tests.mlia.utils.common import train_model +from tests.utils.common import get_dataset +from tests.utils.common import train_model def _test_sparsity( diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/test_nn_tensorflow_optimizations_select.py index 5cac8ba..5cac8ba 100644 --- a/tests/mlia/test_nn_tensorflow_optimizations_select.py +++ b/tests/test_nn_tensorflow_optimizations_select.py diff --git a/tests/mlia/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py index cf7aaeb..00eacef 100644 --- a/tests/mlia/test_nn_tensorflow_tflite_metrics.py +++ b/tests/test_nn_tensorflow_tflite_metrics.py @@ -77,9 +77,7 @@ class TestTFLiteMetrics: # Check sparsity calculation sparsity_per_layer = metrics.sparsity_per_layer() for name, sparsity in sparsity_per_layer.items(): - assert isclose(sparsity, 0.5), "Layer '{}' has incorrect sparsity.".format( - name - ) + assert isclose(sparsity, 0.5), f"Layer '{name}' has incorrect sparsity." assert isclose(metrics.sparsity_overall(), 0.5) @staticmethod @@ -95,7 +93,7 @@ class TestTFLiteMetrics: for num_unique in num_unique_per_axis: assert ( num_unique == 2 - ), "Layer '{}' has incorrect number of clusters.".format(name) + ), f"Layer '{name}' has incorrect number of clusters." # NUM_CLUSTERS_HISTOGRAM hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM) assert hists @@ -105,15 +103,13 @@ class TestTFLiteMetrics: # The histogram starts with the bin for for num_clusters == 1 num_clusters = idx + 1 msg = ( - "Histogram of layer '{}': There are {} axes with {} " - "clusters".format(name, num_axes, num_clusters) + f"Histogram of layer '{name}': There are {num_axes} axes " + f"with {num_clusters} clusters" ) if num_clusters == 2: - assert num_axes > 0, "{}, but there should be at least one.".format( - msg - ) + assert num_axes > 0, f"{msg}, but there should be at least one." else: - assert num_axes == 0, "{}, but there should be none.".format(msg) + assert num_axes == 0, f"{msg}, but there should be none." @staticmethod @pytest.mark.parametrize("report_sparsity", (False, True)) diff --git a/tests/mlia/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index 6d27299..6d27299 100644 --- a/tests/mlia/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py diff --git a/tests/mlia/test_resources/application_config.json b/tests/test_resources/application_config.json index 2dfcfec..8c5d2b5 100644 --- a/tests/mlia/test_resources/application_config.json +++ b/tests/test_resources/application_config.json @@ -6,7 +6,6 @@ "system_1", "system_2" ], - "build_dir": "build_dir_11", "commands": { "clean": [ "clean_cmd_11" @@ -56,7 +55,6 @@ "supported_systems": [ "system_2" ], - "build_dir": "build_dir_21", "commands": { "clean": [ "clean_cmd_21" diff --git a/tests/mlia/test_resources/application_config.json.license b/tests/test_resources/application_config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/application_config.json.license +++ b/tests/test_resources/application_config.json.license diff --git a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json b/tests/test_resources/backends/applications/application1/backend-config.json index 97f0401..96d5420 100644 --- a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json +++ b/tests/test_resources/backends/applications/application1/backend-config.json @@ -7,7 +7,6 @@ "name": "System 1" } ], - "build_dir": "build", "commands": { "clean": [ "echo 'clean'" diff --git a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license b/tests/test_resources/backends/applications/application1/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license +++ b/tests/test_resources/backends/applications/application1/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json b/tests/test_resources/backends/applications/application2/backend-config.json index e9122d3..3a3969a 100644 --- a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json +++ b/tests/test_resources/backends/applications/application2/backend-config.json @@ -7,7 +7,6 @@ "name": "System 2" } ], - "build_dir": "build", "commands": { "clean": [ "echo 'clean'" diff --git a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license b/tests/test_resources/backends/applications/application2/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license +++ b/tests/test_resources/backends/applications/application2/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/applications/application3/readme.txt b/tests/test_resources/backends/applications/application3/readme.txt index 8c72c05..8c72c05 100644 --- a/tests/mlia/test_resources/backends/applications/application3/readme.txt +++ b/tests/test_resources/backends/applications/application3/readme.txt diff --git a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json b/tests/test_resources/backends/applications/application4/backend-config.json index ffb5746..d4362be 100644 --- a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json +++ b/tests/test_resources/backends/applications/application4/backend-config.json @@ -2,23 +2,23 @@ { "name": "application_4", "description": "This is application 4", - "build_dir": "build", + "variables": { + "build_dir": "build" + }, "supported_systems": [ { "name": "System 4" } ], "commands": { - "build": [ - "cp ../hello_app.txt .", - "echo '{user_params:0}' > params.txt" - ], "run": [ - "cat {application.build_dir}/hello_app.txt" + "cp {application.config_dir}/hello_app.txt {system.config_dir}", + "echo '{user_params:0}' > {system.config_dir}/params.txt", + "cat hello_app.txt" ] }, "user_params": { - "build": [ + "run": [ { "name": "--app", "description": "Sample command param", @@ -29,8 +29,7 @@ ], "default_value": "application1" } - ], - "run": [] + ] } } ] diff --git a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license b/tests/test_resources/backends/applications/application4/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license +++ b/tests/test_resources/backends/applications/application4/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/applications/application4/hello_app.txt b/tests/test_resources/backends/applications/application4/hello_app.txt index 2ec0d1d..2ec0d1d 100644 --- a/tests/mlia/test_resources/backends/applications/application4/hello_app.txt +++ b/tests/test_resources/backends/applications/application4/hello_app.txt diff --git a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json b/tests/test_resources/backends/applications/application5/backend-config.json index 5269409..219494c 100644 --- a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json +++ b/tests/test_resources/backends/applications/application5/backend-config.json @@ -2,11 +2,9 @@ { "name": "application_5", "description": "This is application 5", - "build_dir": "default_build_dir", "supported_systems": [ { - "name": "System 1", - "lock": false + "name": "System 1" }, { "name": "System 2" @@ -16,7 +14,6 @@ "var1": "value1", "var2": "value2" }, - "lock": true, "commands": { "build": [ "default build command" @@ -36,7 +33,6 @@ "supported_systems": [ { "name": "System 1", - "build_dir": "build_5A", "variables": { "var1": "new value1" } @@ -46,7 +42,6 @@ "variables": { "var2": "new value2" }, - "lock": true, "commands": { "run": [ "run command on system 2" @@ -58,7 +53,6 @@ "var1": "value1", "var2": "value2" }, - "build_dir": "build", "commands": { "build": [ "default build command" @@ -78,25 +72,9 @@ "supported_systems": [ { "name": "System 1", - "build_dir": "build_5B", "variables": { "var1": "value for var1 System1", "var2": "value for var2 System1" - }, - "user_params": { - "build": [ - { - "name": "--param_5B", - "description": "Sample command param", - "values": [ - "value1", - "value2", - "value3" - ], - "default_value": "value1", - "alias": "param1" - } - ] } }, { @@ -106,9 +84,6 @@ "var2": "value for var2 System2" }, "commands": { - "build": [ - "build command on system 2 with {variables:var1} {user_params:param1}" - ], "run": [ "run command on system 2" ] @@ -118,7 +93,6 @@ } } ], - "build_dir": "build", "commands": { "build": [ "default build command with {variables:var1}" diff --git a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license b/tests/test_resources/backends/applications/application5/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license +++ b/tests/test_resources/backends/applications/application5/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json b/tests/test_resources/backends/applications/application6/backend-config.json index 56ad807..81afebd 100644 --- a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json +++ b/tests/test_resources/backends/applications/application6/backend-config.json @@ -7,7 +7,6 @@ "name": "System 6" } ], - "build_dir": "build", "commands": { "clean": [ "echo 'clean'" diff --git a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license b/tests/test_resources/backends/applications/application6/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license +++ b/tests/test_resources/backends/applications/application6/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/applications/readme.txt b/tests/test_resources/backends/applications/readme.txt index a1f8209..a1f8209 100644 --- a/tests/mlia/test_resources/backends/applications/readme.txt +++ b/tests/test_resources/backends/applications/readme.txt diff --git a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json b/tests/test_resources/backends/systems/system1/backend-config.json index 4b5dd19..4454695 100644 --- a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json +++ b/tests/test_resources/backends/systems/system1/backend-config.json @@ -2,14 +2,6 @@ { "name": "System 1", "description": "This is system 1", - "build_dir": "build", - "data_transfer": { - "protocol": "ssh", - "username": "root", - "password": "root", - "hostname": "localhost", - "port": "8021" - }, "commands": { "clean": [ "echo 'clean'" @@ -22,9 +14,6 @@ ], "post_run": [ "echo 'post_run'" - ], - "deploy": [ - "echo 'deploy'" ] }, "user_params": { diff --git a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license b/tests/test_resources/backends/systems/system1/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license +++ b/tests/test_resources/backends/systems/system1/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt b/tests/test_resources/backends/systems/system1/system_artifact/dummy.txt index 487e9d8..487e9d8 100644 --- a/tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt +++ b/tests/test_resources/backends/systems/system1/system_artifact/dummy.txt diff --git a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json b/tests/test_resources/backends/systems/system2/backend-config.json index a9e0eb3..3359d3d 100644 --- a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json +++ b/tests/test_resources/backends/systems/system2/backend-config.json @@ -2,14 +2,6 @@ { "name": "System 2", "description": "This is system 2", - "build_dir": "build", - "data_transfer": { - "protocol": "ssh", - "username": "root", - "password": "root", - "hostname": "localhost", - "port": "8021" - }, "commands": { "clean": [ "echo 'clean'" diff --git a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license b/tests/test_resources/backends/systems/system2/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license +++ b/tests/test_resources/backends/systems/system2/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/systems/system3/readme.txt b/tests/test_resources/backends/systems/system3/readme.txt index aba5a9c..aba5a9c 100644 --- a/tests/mlia/test_resources/backends/systems/system3/readme.txt +++ b/tests/test_resources/backends/systems/system3/readme.txt diff --git a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json b/tests/test_resources/backends/systems/system4/backend-config.json index 7b13160..7701c05 100644 --- a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json +++ b/tests/test_resources/backends/systems/system4/backend-config.json @@ -2,14 +2,12 @@ { "name": "System 4", "description": "This is system 4", - "build_dir": "build", - "data_transfer": { - "protocol": "local" - }, "commands": { "run": [ "echo {application.name}", - "{application.commands.run:0}" + "{application.commands.run:0}", + "{application.commands.run:1}", + "{application.commands.run:2}" ] }, "user_params": { diff --git a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license b/tests/test_resources/backends/systems/system4/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license +++ b/tests/test_resources/backends/systems/system4/backend-config.json.license diff --git a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json b/tests/test_resources/backends/systems/system6/backend-config.json index 4242f64..5180799 100644 --- a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json +++ b/tests/test_resources/backends/systems/system6/backend-config.json @@ -2,10 +2,6 @@ { "name": "System 6", "description": "This is system 6", - "build_dir": "build", - "data_transfer": { - "protocol": "local" - }, "variables": { "var1": "{user_params:sys-param1}" }, diff --git a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license b/tests/test_resources/backends/systems/system6/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license +++ b/tests/test_resources/backends/systems/system6/backend-config.json.license diff --git a/tests/mlia/test_resources/hello_world.json b/tests/test_resources/hello_world.json index 8a9a448..99e9439 100644 --- a/tests/mlia/test_resources/hello_world.json +++ b/tests/test_resources/hello_world.json @@ -5,7 +5,6 @@ "supported_systems": [ "Dummy System" ], - "build_dir": "build", "deploy_data": [ [ "src", diff --git a/tests/mlia/test_resources/hello_world.json.license b/tests/test_resources/hello_world.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/hello_world.json.license +++ b/tests/test_resources/hello_world.json.license diff --git a/tests/mlia/test_resources/scripts/test_backend_run b/tests/test_resources/scripts/test_backend_run index 548f577..548f577 100755 --- a/tests/mlia/test_resources/scripts/test_backend_run +++ b/tests/test_resources/scripts/test_backend_run diff --git a/tests/mlia/test_resources/scripts/test_backend_run_script.sh b/tests/test_resources/scripts/test_backend_run_script.sh index 548f577..548f577 100644 --- a/tests/mlia/test_resources/scripts/test_backend_run_script.sh +++ b/tests/test_resources/scripts/test_backend_run_script.sh diff --git a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/test_resources/various/applications/application_with_empty_config/backend-config.json index fe51488..fe51488 100644 --- a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json +++ b/tests/test_resources/various/applications/application_with_empty_config/backend-config.json diff --git a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/test_resources/various/applications/application_with_empty_config/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license +++ b/tests/test_resources/various/applications/application_with_empty_config/backend-config.json.license diff --git a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/test_resources/various/applications/application_with_valid_config/backend-config.json index ff1cf1a..a457d9b 100644 --- a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json +++ b/tests/test_resources/various/applications/application_with_valid_config/backend-config.json @@ -2,7 +2,9 @@ { "name": "test_application", "description": "This is test_application", - "build_dir": "build", + "variables": { + "build_dir": "build" + }, "supported_systems": [ { "name": "System 4" @@ -13,7 +15,7 @@ "cp ../hello_app.txt ." ], "run": [ - "{application.build_dir}/hello_app.txt" + "{application.variables:build_dir}/hello_app.txt" ] }, "user_params": { diff --git a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/test_resources/various/applications/application_with_valid_config/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license +++ b/tests/test_resources/various/applications/application_with_valid_config/backend-config.json.license diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json index 724b31b..724b31b 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json +++ b/tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license +++ b/tests/test_resources/various/applications/application_with_wrong_config1/backend-config.json.license diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json index 1ebb29c..b64e6f8 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json +++ b/tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json @@ -2,13 +2,15 @@ { "name": "test_application", "description": "This is test_application", - "build_dir": "build", + "variables": { + "build_dir": "build" + }, "commands": { "build": [ "cp ../hello_app.txt ." ], "run": [ - "{application.build_dir}/hello_app.txt" + "{application.variables:build_dir}/hello_app.txt" ] }, "user_params": { diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license +++ b/tests/test_resources/various/applications/application_with_wrong_config2/backend-config.json.license diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json index 410d12d..4a70cdd 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json +++ b/tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json @@ -2,7 +2,9 @@ { "name": "test_application", "description": "This is test_application", - "build_dir": "build", + "variables": { + "build_dir": "build" + }, "supported_systems": [ { "anme": "System 4" @@ -13,7 +15,7 @@ "cp ../hello_app.txt ." ], "run": [ - "{application.build_dir}/hello_app.txt" + "{application.variables:build_dir}/hello_app.txt" ] }, "user_params": { diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license +++ b/tests/test_resources/various/applications/application_with_wrong_config3/backend-config.json.license diff --git a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/test_resources/various/systems/system_with_empty_config/backend-config.json index fe51488..fe51488 100644 --- a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json +++ b/tests/test_resources/various/systems/system_with_empty_config/backend-config.json diff --git a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/test_resources/various/systems/system_with_empty_config/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license +++ b/tests/test_resources/various/systems/system_with_empty_config/backend-config.json.license diff --git a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/test_resources/various/systems/system_with_valid_config/backend-config.json index 20142e9..83c3025 100644 --- a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json +++ b/tests/test_resources/various/systems/system_with_valid_config/backend-config.json @@ -2,10 +2,6 @@ { "name": "Test system", "description": "This is a test system", - "build_dir": "build", - "data_transfer": { - "protocol": "local" - }, "commands": { "run": [] }, diff --git a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/test_resources/various/systems/system_with_valid_config/backend-config.json.license index 9b83bfc..9b83bfc 100644 --- a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license +++ b/tests/test_resources/various/systems/system_with_valid_config/backend-config.json.license diff --git a/tests/mlia/test_resources/vela/sample_vela.ini b/tests/test_resources/vela/sample_vela.ini index c992458..c992458 100644 --- a/tests/mlia/test_resources/vela/sample_vela.ini +++ b/tests/test_resources/vela/sample_vela.ini diff --git a/tests/mlia/test_tools_metadata_common.py b/tests/test_tools_metadata_common.py index 7663b83..7663b83 100644 --- a/tests/mlia/test_tools_metadata_common.py +++ b/tests/test_tools_metadata_common.py diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/test_tools_metadata_corstone.py index 017d0c7..017d0c7 100644 --- a/tests/mlia/test_tools_metadata_corstone.py +++ b/tests/test_tools_metadata_corstone.py diff --git a/tests/mlia/test_tools_vela_wrapper.py b/tests/test_tools_vela_wrapper.py index 875d2ff..0efcb0f 100644 --- a/tests/mlia/test_tools_vela_wrapper.py +++ b/tests/test_tools_vela_wrapper.py @@ -20,7 +20,7 @@ from mlia.tools.vela_wrapper import PerformanceMetrics from mlia.tools.vela_wrapper import supported_operators from mlia.tools.vela_wrapper import VelaCompiler from mlia.tools.vela_wrapper import VelaCompilerOptions -from mlia.utils.proc import working_directory +from mlia.utils.filesystem import working_directory def test_default_vela_compiler() -> None: diff --git a/tests/mlia/test_utils_console.py b/tests/test_utils_console.py index 36975f8..36975f8 100644 --- a/tests/mlia/test_utils_console.py +++ b/tests/test_utils_console.py diff --git a/tests/mlia/test_utils_download.py b/tests/test_utils_download.py index 4f8e2dc..4f8e2dc 100644 --- a/tests/mlia/test_utils_download.py +++ b/tests/test_utils_download.py diff --git a/tests/mlia/test_utils_filesystem.py b/tests/test_utils_filesystem.py index 4d8d955..7cf32e7 100644 --- a/tests/mlia/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -20,6 +20,7 @@ from mlia.utils.filesystem import get_vela_config from mlia.utils.filesystem import sha256 from mlia.utils.filesystem import temp_directory from mlia.utils.filesystem import temp_file +from mlia.utils.filesystem import working_directory def test_get_mlia_resources() -> None: @@ -164,3 +165,27 @@ def test_copy_all(tmp_path: Path) -> None: assert (dest_dir / sample.name).is_file() assert (dest_dir / sample_nested_file.name).is_file() + + +@pytest.mark.parametrize( + "should_exist, create_dir", + [ + [True, False], + [False, True], + ], +) +def test_working_directory_context_manager( + tmp_path: Path, should_exist: bool, create_dir: bool +) -> None: + """Test working_directory context manager.""" + prev_wd = Path.cwd() + + working_dir = tmp_path / "work_dir" + if should_exist: + working_dir.mkdir() + + with working_directory(working_dir, create_dir=create_dir) as current_working_dir: + assert current_working_dir.is_dir() + assert Path.cwd() == current_working_dir + + assert Path.cwd() == prev_wd diff --git a/tests/mlia/test_utils_logging.py b/tests/test_utils_logging.py index 75ebceb..75ebceb 100644 --- a/tests/mlia/test_utils_logging.py +++ b/tests/test_utils_logging.py diff --git a/tests/mlia/test_utils_misc.py b/tests/test_utils_misc.py index 011d09e..011d09e 100644 --- a/tests/mlia/test_utils_misc.py +++ b/tests/test_utils_misc.py diff --git a/tests/mlia/test_utils_types.py b/tests/test_utils_types.py index 4909efe..4909efe 100644 --- a/tests/mlia/test_utils_types.py +++ b/tests/test_utils_types.py diff --git a/tests/mlia/utils/__init__.py b/tests/utils/__init__.py index 27166ef..27166ef 100644 --- a/tests/mlia/utils/__init__.py +++ b/tests/utils/__init__.py diff --git a/tests/mlia/utils/common.py b/tests/utils/common.py index 932343e..932343e 100644 --- a/tests/mlia/utils/common.py +++ b/tests/utils/common.py diff --git a/tests/mlia/utils/logging.py b/tests/utils/logging.py index d223fb2..d223fb2 100644 --- a/tests/mlia/utils/logging.py +++ b/tests/utils/logging.py |