From c9b4089b3037b5943565d76242d3016b8776f8d2 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Tue, 28 Jun 2022 10:29:35 +0100 Subject: MLIA-546 Merge AIET into MLIA Merge the deprecated AIET interface for backend execution into MLIA: - Execute backends directly (without subprocess and the aiet CLI) - Fix issues with the unit tests - Remove src/aiet and tests/aiet - Re-factor code to replace 'aiet' with 'backend' - Adapt and improve unit tests after re-factoring - Remove dependencies that are not needed anymore (click and cloup) Change-Id: I450734c6a3f705ba9afde41862b29e797e511f7c --- src/mlia/backend/__init__.py | 3 + src/mlia/backend/application.py | 187 +++++ src/mlia/backend/common.py | 532 ++++++++++++++ src/mlia/backend/config.py | 93 +++ src/mlia/backend/controller.py | 134 ++++ src/mlia/backend/execution.py | 779 +++++++++++++++++++++ src/mlia/backend/fs.py | 115 +++ src/mlia/backend/manager.py | 447 ++++++++++++ src/mlia/backend/output_parser.py | 176 +++++ src/mlia/backend/proc.py | 283 ++++++++ src/mlia/backend/protocol.py | 325 +++++++++ src/mlia/backend/source.py | 209 ++++++ src/mlia/backend/system.py | 289 ++++++++ src/mlia/cli/config.py | 6 +- src/mlia/devices/ethosu/performance.py | 37 +- .../resources/aiet/applications/APPLICATIONS.txt | 5 +- src/mlia/resources/aiet/systems/SYSTEMS.txt | 3 +- .../resources/backends/applications/.gitignore | 6 + src/mlia/resources/backends/systems/.gitignore | 6 + src/mlia/tools/aiet_wrapper.py | 435 ------------ src/mlia/tools/metadata/corstone.py | 61 +- src/mlia/utils/proc.py | 20 +- 22 files changed, 3645 insertions(+), 506 deletions(-) create mode 100644 src/mlia/backend/__init__.py create mode 100644 src/mlia/backend/application.py create mode 100644 src/mlia/backend/common.py create mode 100644 src/mlia/backend/config.py create mode 100644 src/mlia/backend/controller.py create mode 100644 src/mlia/backend/execution.py create mode 100644 src/mlia/backend/fs.py create mode 100644 
src/mlia/backend/manager.py create mode 100644 src/mlia/backend/output_parser.py create mode 100644 src/mlia/backend/proc.py create mode 100644 src/mlia/backend/protocol.py create mode 100644 src/mlia/backend/source.py create mode 100644 src/mlia/backend/system.py create mode 100644 src/mlia/resources/backends/applications/.gitignore create mode 100644 src/mlia/resources/backends/systems/.gitignore delete mode 100644 src/mlia/tools/aiet_wrapper.py (limited to 'src/mlia') diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py new file mode 100644 index 0000000..3d60372 --- /dev/null +++ b/src/mlia/backend/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend module.""" diff --git a/src/mlia/backend/application.py b/src/mlia/backend/application.py new file mode 100644 index 0000000..eb85212 --- /dev/null +++ b/src/mlia/backend/application.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
def get_available_applications() -> List["Application"]:
    """Collect every application defined by the installed backend configs.

    Each backend config file may define several applications; the combined
    list is returned sorted by application name.
    """
    apps: List["Application"] = []
    for cfg_path in get_backend_configs("applications"):
        entries = cast(List[ExtendedApplicationConfig], load_config(cfg_path))
        for entry in entries:
            # Remember where the config lives so relative paths inside the
            # application definition can be resolved later.
            entry["config_location"] = cfg_path.parent.absolute()
            apps.extend(load_applications(entry))

    return sorted(apps, key=lambda app: app.name)
def get_unique_application_names(system_name: Optional[str] = None) -> List[str]:
    """Return the de-duplicated names of all available applications.

    When a truthy *system_name* is given, only applications that can run on
    that system are considered.
    """
    names = {
        app.name
        for app in get_available_applications()
        if not system_name or app.can_run_on(system_name)
    }
    return list(names)
other.name + and set(self.supported_systems) == set(other.supported_systems) + ) + + def can_run_on(self, system_name: str) -> bool: + """Check if the application can run on the system passed as argument.""" + return system_name in self.supported_systems + + def get_deploy_data(self) -> List[DataPaths]: + """Validate and return data specified in the config file.""" + if self.config_location is None: + raise ConfigurationException( + "Unable to get application {} config location".format(self.name) + ) + + deploy_data = [] + for item in self.deploy_data: + src, dst = item + src_full_path = self.config_location / src + assert src_full_path.exists(), "{} does not exists".format(src_full_path) + deploy_data.append(DataPaths(src_full_path, dst)) + return deploy_data + + def get_details(self) -> Dict[str, Any]: + """Return dictionary with information about the Application instance.""" + output = { + "type": "application", + "name": self.name, + "description": self.description, + "supported_systems": self.supported_systems, + "commands": self._get_command_details(), + } + + return output + + def remove_unused_params(self) -> None: + """Remove unused params in commands. + + After merging default and system related configuration application + could have parameters that are not being used in commands. They + should be removed. + """ + for command in self.commands.values(): + indexes_or_aliases = [ + m + for cmd_str in command.command_strings + for m in re.findall(r"{user_params:(?P\w+)}", cmd_str) + ] + + only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) + if only_aliases: + used_params = [ + param + for param in command.params + if param.alias in indexes_or_aliases + ] + command.params = used_params + + +def load_applications(config: ExtendedApplicationConfig) -> List[Application]: + """Load application. + + Application configuration could contain different parameters/commands for different + supported systems. 
def get_backend_config(dir_path: Path) -> Path:
    """Return the path to the backend configuration file inside *dir_path*."""
    # The file name is fixed by BACKEND_CONFIG_FILE ("aiet-config.json").
    return dir_path / BACKEND_CONFIG_FILE
def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
    """Split the parameter string in name and optional value.

    It manages the following cases:
    --param=1 -> --param, 1
    --param 1 -> --param, 1
    --flag    -> --flag, None

    Only the first separator is significant, so a value may itself contain
    "=" (e.g. "--key=a=b" -> "--key", "a=b").
    """
    # Split on the first space or "=" only. Splitting on every separator (the
    # previous behaviour) mangled values containing "=": the name became a
    # space-joined mixture of name and partial value.
    data = re.split(" |=", parameter, maxsplit=1)
    if len(data) == 1:
        return data[0], None
    return data[0], data[1]
+ for param in command.params: + if self._same_parameter(param_name, param): + valid_param_name = True + # This is a non-empty list + if param.values: + # We check if the value is allowed in the configuration + valid_param_value = param_value in param.values + else: + # In this case we don't validate the value and accept + # whatever we have set. + valid_param_value = True + break + + return valid_param_name and valid_param_value + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Backend): + return False + + return ( + self.name == other.name + and self.description == other.description + and self.commands == other.commands + ) + + def __repr__(self) -> str: + """Represent the Backend instance by its name.""" + return self.name + + def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: + """Parse commands and user parameters.""" + self.commands: Dict[str, Command] = {} + + commands = config.get("commands") + if commands: + params = config.get("user_params") + + for command_name in commands.keys(): + command_params = self._parse_params(params, command_name) + command_strings = [ + self._substitute_variables(cmd) + for cmd in commands.get(command_name, []) + ] + self.commands[command_name] = Command(command_strings, command_params) + + def _substitute_variables(self, str_val: str) -> str: + """Substitute variables in string. + + Variables is being substituted at backend's creation stage because + they could contain references to other params which will be + resolved later. 
+ """ + if not str_val: + return str_val + + var_pattern: Final[Pattern] = re.compile(r"{variables:(?P\w+)}") + + def var_value(match: Match) -> str: + var_name = match["var_name"] + if var_name not in self.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return self.variables[var_name] + + return var_pattern.sub(var_value, str_val) # type: ignore + + @classmethod + def _parse_params( + cls, params: Optional[UserParamsConfig], command: str + ) -> List["Param"]: + if not params: + return [] + + return [cls._parse_param(p) for p in params.get(command, [])] + + @classmethod + def _parse_param(cls, param: UserParamConfig) -> "Param": + """Parse a single parameter.""" + name = param.get("name") + if name is not None and not name: + raise ConfigurationException("Parameter has an empty 'name' attribute.") + values = param.get("values", None) + default_value = param.get("default_value", None) + description = param.get("description", "") + alias = param.get("alias") + + return Param( + name=name, + description=description, + values=values, + default_value=default_value, + alias=alias, + ) + + def _get_command_details(self) -> Dict: + command_details = { + command_name: command.get_details() + for command_name, command in self.commands.items() + } + return command_details + + def _get_user_param_value( + self, user_params: List[str], param: "Param" + ) -> Optional[str]: + """Get the user-specified value of a parameter.""" + for user_param in user_params: + user_param_name, user_param_value = parse_raw_parameter(user_param) + if user_param_name == param.name: + warn_message = ( + "The direct use of parameter name is deprecated" + " and might be removed in the future.\n" + f"Please use alias '{param.alias}' instead of " + "'{user_param_name}' to provide the parameter." 
    @staticmethod
    def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
        """Compare user parameter name with param name or alias.

        Returns True when the given string matches either the parameter's
        configured name (ignoring a trailing "=") or its alias.
        """
        # Strip the "=" sign in the param_name. This is needed just for
        # comparison with the parameters passed by the user.
        # The equal sign needs to be honoured when re-building the
        # parameter back.
        param_name = None if not param.name else param.name.rstrip("=")
        return user_param_name_or_alias in [param_name, param.alias]
+ """ + command = self.commands.get(command_name) + if not command: + raise ConfigurationException( + "Command '{}' could not be found.".format(command_name) + ) + + commands_to_run = [] + + params_values = self.resolved_parameters(command_name, user_params) + for cmd_str in command.command_strings: + cmd_str = resolve_all_parameters( + cmd_str, param_resolver, command_name, params_values + ) + commands_to_run.append(cmd_str) + + return commands_to_run + + +class Param: + """Class for representing a generic application parameter.""" + + def __init__( # pylint: disable=too-many-arguments + self, + name: Optional[str], + description: str, + values: Optional[List[str]] = None, + default_value: Optional[str] = None, + alias: Optional[str] = None, + ) -> None: + """Construct a Param instance.""" + if not name and not alias: + raise ConfigurationException( + "Either name, alias or both must be set to identify a parameter." + ) + self.name = name + self.values = values + self.description = description + self.default_value = default_value + self.alias = alias + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Param.""" + return {key: value for key, value in self.__dict__.items() if value} + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Param): + return False + + return ( + self.name == other.name + and self.values == other.values + and self.default_value == other.default_value + and self.description == other.description + ) + + +class Command: + """Class for representing a command.""" + + def __init__( + self, command_strings: List[str], params: Optional[List[Param]] = None + ) -> None: + """Construct a Command instance.""" + self.command_strings = command_strings + + if params: + self.params = params + else: + self.params = [] + + self._validate() + + def _validate(self) -> None: + """Validate command.""" + if not self.params: + return + + aliases = [param.alias for param in 
def resolve_all_parameters(
    str_val: str,
    param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
    command_name: Optional[str] = None,
    params_values: Optional[List[Tuple[Optional[str], "Param"]]] = None,
) -> str:
    """Resolve all parameters in the string.

    Repeatedly substitutes "{...}" placeholders via *param_resolver* until no
    placeholder remains, so a resolved value may itself contain further
    placeholders.
    """
    if not str_val:
        return str_val

    # The named group had lost its "<param_name>" name ("(?P[\w.:]+)"), which
    # is invalid regex syntax; the lambda below reads m["param_name"], so the
    # group name is restored here. Placeholder names may contain word
    # characters, dots and colons (e.g. "{user_params:0}").
    param_pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}")
    while param_pattern.findall(str_val):
        str_val = param_pattern.sub(
            lambda m: param_resolver(
                m["param_name"], command_name or "", params_values or []
            ),
            str_val,
        )
    return str_val
For each supported system this function will return separate + config with appropriate configuration. + """ + merged_configs = [] + supported_systems: Optional[List[NamedExecutionConfig]] = config.get( + "supported_systems" + ) + if not supported_systems: + if is_system_required: + raise ConfigurationException("No supported systems definition provided") + # Create an empty system to be used in the parsing below + supported_systems = [cast(NamedExecutionConfig, {})] + + default_user_params = config.get("user_params", {}) + + def merge_config(system: NamedExecutionConfig) -> Any: + system_name = system.get("name") + if not system_name and is_system_required: + raise ConfigurationException( + "Unable to read supported system definition, name is missed" + ) + + merged_config = config_type(**config) + merged_config["supported_systems"] = [system_name] if system_name else [] + # merge default configuration and specific to the system + merged_config["commands"] = { + **config.get("commands", {}), + **system.get("commands", {}), + } + + params = {} + tool_user_params = system.get("user_params", {}) + command_names = tool_user_params.keys() | default_user_params.keys() + for command_name in command_names: + if command_name not in merged_config["commands"]: + continue + + params_default = default_user_params.get(command_name, []) + params_tool = tool_user_params.get(command_name, []) + if not params_default or not params_tool: + params[command_name] = params_tool or params_default + if params_default and params_tool: + if any(not p.get("alias") for p in params_default): + raise ConfigurationException( + "Default parameters for command {} should have aliases".format( + command_name + ) + ) + if any(not p.get("alias") for p in params_tool): + raise ConfigurationException( + "{} parameters for command {} should have aliases".format( + system_name, command_name + ) + ) + + merged_by_alias = { + **{p.get("alias"): p for p in params_default}, + **{p.get("alias"): p for p in 
params_tool}, + } + params[command_name] = list(merged_by_alias.values()) + + merged_config["user_params"] = params + merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) + merged_config["lock"] = system.get("lock", config.get("lock", False)) + merged_config["variables"] = { + **config.get("variables", {}), + **system.get("variables", {}), + } + return merged_config + + merged_configs = [merge_config(system) for system in supported_systems] + + return merged_configs diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py new file mode 100644 index 0000000..657adef --- /dev/null +++ b/src/mlia/backend/config.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain definition of backend configuration.""" +from pathlib import Path +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import TypedDict +from typing import Union + + +class UserParamConfig(TypedDict, total=False): + """User parameter configuration.""" + + name: Optional[str] + default_value: str + values: List[str] + description: str + alias: str + + +UserParamsConfig = Dict[str, List[UserParamConfig]] + + +class ExecutionConfig(TypedDict, total=False): + """Execution configuration.""" + + commands: Dict[str, List[str]] + user_params: UserParamsConfig + build_dir: str + variables: Dict[str, str] + lock: bool + + +class NamedExecutionConfig(ExecutionConfig): + """Execution configuration with name.""" + + name: str + + +class BaseBackendConfig(ExecutionConfig, total=False): + """Base backend configuration.""" + + name: str + description: str + config_location: Path + annotations: Dict[str, Union[str, List[str]]] + + +class ApplicationConfig(BaseBackendConfig, total=False): + """Application configuration.""" + + supported_systems: List[str] + deploy_data: List[Tuple[str, str]] + + +class 
ExtendedApplicationConfig(BaseBackendConfig, total=False): + """Extended application configuration.""" + + supported_systems: List[NamedExecutionConfig] + deploy_data: List[Tuple[str, str]] + + +class ProtocolConfig(TypedDict, total=False): + """Protocol config.""" + + protocol: Literal["local", "ssh"] + + +class SSHConfig(ProtocolConfig, total=False): + """SSH configuration.""" + + username: str + password: str + hostname: str + port: str + + +class LocalProtocolConfig(ProtocolConfig, total=False): + """Local protocol config.""" + + +class SystemConfig(BaseBackendConfig, total=False): + """System configuration.""" + + data_transfer: Union[SSHConfig, LocalProtocolConfig] + reporting: Dict[str, Dict] + + +BackendItemConfig = Union[ApplicationConfig, SystemConfig] +BackendConfig = Union[List[ExtendedApplicationConfig], List[SystemConfig]] diff --git a/src/mlia/backend/controller.py b/src/mlia/backend/controller.py new file mode 100644 index 0000000..f1b68a9 --- /dev/null +++ b/src/mlia/backend/controller.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
    def start(self, commands: List[str], cwd: Path) -> None:
        """Start system.

        Validates the working directory and the command list (exactly one
        non-empty startup command is expected), then launches that command in
        the background with stdout/stderr redirected to files.

        Raises ConfigurationException for an invalid cwd or command list.
        """
        if not isinstance(cwd, Path) or not cwd.is_dir():
            raise ConfigurationException("Wrong working directory {}".format(cwd))

        if len(commands) != 1:
            raise ConfigurationException("System should have only one command to run")

        startup_command = commands[0]
        if not startup_command:
            raise ConfigurationException("No startup command provided")

        # Subclass hook (e.g. single-instance check).
        self.before_start()

        self.out_path, self.err_path = get_stdout_stderr_paths(startup_command)

        # Run in the background; output is captured in the files above.
        self.cmd = execute_command(
            startup_command,
            cwd,
            bg=True,
            out=str(self.out_path),
            err=str(self.err_path),
        )

        # Subclass hook (e.g. saving process info).
        self.after_start()
    def _check_if_previous_instance_is_running(self) -> None:
        """Check if another instance of the system is running.

        Reads previously saved process records from the pid file and, for
        each record that still matches a live process, terminates it.
        """
        process_info = read_process_info(self._pid_file())

        for item in process_info:
            try:
                process = psutil.Process(item.pid)
                # Guard against pid reuse: only treat the process as a
                # previous instance when name, executable and working
                # directory all match the saved record.
                same_process = (
                    process.name() == item.name
                    and process.exe() == item.executable
                    and process.cwd() == item.cwd
                )
                if same_process:
                    print(
                        "Stopping previous instance of the system [{}]".format(item.pid)
                    )
                    terminate_external_process(process)
            except psutil.NoSuchProcess:
                # The recorded process has already exited - nothing to do.
                pass
a/src/mlia/backend/execution.py b/src/mlia/backend/execution.py new file mode 100644 index 0000000..749ccdb --- /dev/null +++ b/src/mlia/backend/execution.py @@ -0,0 +1,779 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Application execution module.""" +import itertools +import json +import random +import re +import string +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from contextlib import ExitStack +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import ContextManager +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypedDict + +from filelock import FileLock +from filelock import Timeout + +from mlia.backend.application import Application +from mlia.backend.application import get_application +from mlia.backend.common import Backend +from mlia.backend.common import ConfigurationException +from mlia.backend.common import DataPaths +from mlia.backend.common import Param +from mlia.backend.common import parse_raw_parameter +from mlia.backend.common import resolve_all_parameters +from mlia.backend.fs import recreate_directory +from mlia.backend.fs import remove_directory +from mlia.backend.fs import valid_for_filename +from mlia.backend.output_parser import Base64OutputParser +from mlia.backend.output_parser import OutputParser +from mlia.backend.output_parser import RegexOutputParser +from mlia.backend.proc import run_and_wait +from mlia.backend.system import ControlledSystem +from mlia.backend.system import get_system +from mlia.backend.system import StandaloneSystem +from mlia.backend.system import System + + +class AnotherInstanceIsRunningException(Exception): + """Concurrent execution error.""" + + +class 
ConnectionException(Exception):
    """Connection exception.

    Raised when a connection to the system cannot be established
    within the configured number of retries.
    """


class ExecutionParams(TypedDict, total=False):
    """Execution parameters.

    total=False: both keys are optional; readers use .get() with defaults.
    """

    disable_locking: bool
    unique_build_dir: bool


class ExecutionContext:
    """Command execution context.

    Bundles the application, the system it runs on, their parameters and
    the collected stdout/stderr of the run.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        app: Application,
        app_params: List[str],
        system: Optional[System],
        system_params: List[str],
        custom_deploy_data: Optional[List[DataPaths]] = None,
        execution_params: Optional[ExecutionParams] = None,
        report_file: Optional[Path] = None,
    ):
        """Init execution context."""
        self.app = app
        self.app_params = app_params
        self.custom_deploy_data = custom_deploy_data or []
        self.system = system
        self.system_params = system_params
        self.execution_params = execution_params or ExecutionParams()
        self.report_file = report_file

        self.reporter: Optional[Reporter]
        if self.report_file:
            # Create reporter with output parsers
            parsers: List[OutputParser] = []
            if system and system.reporting:
                # Add RegexOutputParser, if it is configured in the system
                parsers.append(RegexOutputParser("system", system.reporting["regex"]))
            # Add Base64 parser for applications
            parsers.append(Base64OutputParser("application"))
            self.reporter = Reporter(parsers=parsers)
        else:
            self.reporter = None  # No reporter needed.

        self.param_resolver = ParamResolver(self)
        # Cache for build_dir(); resolved lazily on first access.
        self._resolved_build_dir: Optional[Path] = None

        self.stdout: Optional[bytearray] = None
        self.stderr: Optional[bytearray] = None

    @property
    def is_deploy_needed(self) -> bool:
        """Check if application requires data deployment."""
        return len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0

    @property
    def is_locking_required(self) -> bool:
        """Return true if any form of locking required."""
        # Locking can be switched off globally via execution params even if
        # the app or the system asks for it.
        return not self._disable_locking() and (
            self.app.lock or (self.system is not None and self.system.lock)
        )

    @property
    def is_build_required(self) -> bool:
        """Return true if application build required."""
        return "build" in self.app.commands

    @property
    def is_unique_build_dir_required(self) -> bool:
        """Return true if unique build dir required."""
        return self.execution_params.get("unique_build_dir", False)

    def build_dir(self) -> Path:
        """Return resolved application build dir.

        :raises ConfigurationException: if the application config location
            is invalid or no build directory is defined
        """
        if self._resolved_build_dir is not None:
            return self._resolved_build_dir

        if (
            not isinstance(self.app.config_location, Path)
            or not self.app.config_location.is_dir()
        ):
            raise ConfigurationException(
                "Application {} has wrong config location".format(self.app.name)
            )

        _build_dir = self.app.build_dir
        if _build_dir:
            _build_dir = resolve_all_parameters(_build_dir, self.param_resolver)

        if not _build_dir:
            raise ConfigurationException(
                "No build directory defined for the app {}".format(self.app.name)
            )

        if self.is_unique_build_dir_required:
            # Random 7-char suffix keeps concurrent runs from sharing a dir.
            random_suffix = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=7)
            )
            _build_dir = "{}_{}".format(_build_dir, random_suffix)

        self._resolved_build_dir = self.app.config_location / _build_dir
        return self._resolved_build_dir

    def _disable_locking(self) -> bool:
        """Return true if locking should be disabled."""
        return self.execution_params.get("disable_locking", False)


class ParamResolver:
+ """Parameter resolver.""" + + def __init__(self, context: ExecutionContext): + """Init parameter resolver.""" + self.ctx = context + + @staticmethod + def resolve_user_params( + cmd_name: Optional[str], + index_or_alias: str, + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Resolve user params.""" + if not cmd_name or resolved_params is None: + raise ConfigurationException("Unable to resolve user params") + + param_value: Optional[str] = None + param: Optional[Param] = None + + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(resolved_params)): + raise ConfigurationException( + "Invalid index {} for user params of command {}".format(i, cmd_name) + ) + param_value, param = resolved_params[i] + else: + for val, par in resolved_params: + if par.alias == index_or_alias: + param_value, param = val, par + break + + if param is None: + raise ConfigurationException( + "No user parameter for command '{}' with alias '{}'.".format( + cmd_name, index_or_alias + ) + ) + + if param_value: + # We need to handle to cases of parameters here: + # 1) Optional parameters (non-positional with a name and value) + # 2) Positional parameters (value only, no name needed) + # Default to empty strings for positional arguments + param_name = "" + separator = "" + if param.name is not None: + # A valid param name means we have an optional/non-positional argument: + # The separator is an empty string in case the param_name + # has an equal sign as we have to honour it. 
+ # If the parameter doesn't end with an equal sign then a + # space character is injected to split the parameter name + # and its value + param_name = param.name + separator = "" if param.name.endswith("=") else " " + + return "{param_name}{separator}{param_value}".format( + param_name=param_name, + separator=separator, + param_value=param_value, + ) + + if param.name is None: + raise ConfigurationException( + "Missing user parameter with alias '{}' for command '{}'.".format( + index_or_alias, cmd_name + ) + ) + + return param.name # flag: just return the parameter name + + def resolve_commands_and_params( + self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str + ) -> str: + """Resolve command or command's param value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + backend_params = self.ctx.system_params + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + backend_params = self.ctx.app_params + + if cmd_name not in backend.commands: + raise ConfigurationException("Command {} not found".format(cmd_name)) + + if return_params: + params = backend.resolved_parameters(cmd_name, backend_params) + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(params)): + raise ConfigurationException( + "Invalid parameter index {} for command {}".format(i, cmd_name) + ) + + param_value = params[i][0] + else: + param_value = None + for value, param in params: + if param.alias == index_or_alias: + param_value = value + break + + if not param_value: + raise ConfigurationException( + ( + "No value for parameter with index or alias {} of command {}" + ).format(index_or_alias, cmd_name) + ) + return param_value + + if not index_or_alias.isnumeric(): + raise ConfigurationException("Bad command index {}".format(index_or_alias)) + + i = int(index_or_alias) + commands = backend.build_command(cmd_name, backend_params, self.param_resolver) + if i not in range(len(commands)): + 
raise ConfigurationException( + "Invalid index {} for command {}".format(i, cmd_name) + ) + + return commands[i] + + def resolve_variables(self, backend_type: str, var_name: str) -> str: + """Resolve variable value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + + if var_name not in backend.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return backend.variables[var_name] + + def param_matcher( + self, + param_name: str, + cmd_name: Optional[str], + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Regexp to resolve a param from the param_name.""" + # this pattern supports parameter names like "application.commands.run:0" and + # "system.commands.run.params:0" + # Note: 'software' is included for backward compatibility. + commands_and_params_match = re.match( + r"(?Papplication|software|tool|system)[.]commands[.]" + r"(?P\w+)" + r"(?P[.]params|)[:]" + r"(?P\w+)", + param_name, + ) + + if commands_and_params_match: + backend_type, cmd_name, return_params, index_or_alias = ( + commands_and_params_match["type"], + commands_and_params_match["name"], + commands_and_params_match["params"], + commands_and_params_match["index_or_alias"], + ) + return self.resolve_commands_and_params( + backend_type, cmd_name, bool(return_params), index_or_alias + ) + + # Note: 'software' is included for backward compatibility. 
+ variables_match = re.match( + r"(?Papplication|software|tool|system)[.]variables:(?P\w+)", + param_name, + ) + if variables_match: + backend_type, var_name = ( + variables_match["type"], + variables_match["var_name"], + ) + return self.resolve_variables(backend_type, var_name) + + user_params_match = re.match(r"user_params:(?P\w+)", param_name) + if user_params_match: + index_or_alias = user_params_match["index_or_alias"] + return self.resolve_user_params(cmd_name, index_or_alias, resolved_params) + + raise ConfigurationException( + "Unable to resolve parameter {}".format(param_name) + ) + + def param_resolver( + self, + param_name: str, + cmd_name: Optional[str] = None, + resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + ) -> str: + """Resolve parameter value based on current execution context.""" + # Note: 'software.*' is included for backward compatibility. + resolved_param = None + if param_name in ["application.name", "tool.name", "software.name"]: + resolved_param = self.ctx.app.name + elif param_name in [ + "application.description", + "tool.description", + "software.description", + ]: + resolved_param = self.ctx.app.description + elif self.ctx.app.config_location and ( + param_name + in ["application.config_dir", "tool.config_dir", "software.config_dir"] + ): + resolved_param = str(self.ctx.app.config_location.absolute()) + elif self.ctx.app.build_dir and ( + param_name + in ["application.build_dir", "tool.build_dir", "software.build_dir"] + ): + resolved_param = str(self.ctx.build_dir().absolute()) + elif self.ctx.system is not None: + if param_name == "system.name": + resolved_param = self.ctx.system.name + elif param_name == "system.description": + resolved_param = self.ctx.system.description + elif param_name == "system.config_dir" and self.ctx.system.config_location: + resolved_param = str(self.ctx.system.config_location.absolute()) + + if not resolved_param: + resolved_param = self.param_matcher(param_name, cmd_name, 
resolved_params) + return resolved_param + + def __call__( + self, + param_name: str, + cmd_name: Optional[str] = None, + resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, + ) -> str: + """Resolve provided parameter.""" + return self.param_resolver(param_name, cmd_name, resolved_params) + + +class Reporter: + """Report metrics from the simulation output.""" + + def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None: + """Create an empty reporter (i.e. no parsers registered).""" + self.parsers: List[OutputParser] = parsers if parsers is not None else [] + self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) + + def parse(self, output: bytearray) -> None: + """Parse output and append parsed metrics to internal report dict.""" + for parser in self.parsers: + # Merge metrics from different parsers (do not overwrite) + self._report[parser.name]["metrics"].update(parser(output)) + + def get_filtered_output(self, output: bytearray) -> bytearray: + """Filter the output according to each parser.""" + for parser in self.parsers: + output = parser.filter_out_parsed_content(output) + return output + + def report(self, ctx: ExecutionContext) -> Dict[str, Any]: + """Add static simulation info to parsed data and return the report.""" + report: Dict[str, Any] = defaultdict(dict) + # Add static simulation info + report.update(self._static_info(ctx)) + # Add metrics parsed from the output + for key, val in self._report.items(): + report[key].update(val) + return report + + @staticmethod + def save(report: Dict[str, Any], report_file: Path) -> None: + """Save the report to a JSON file.""" + with open(report_file, "w", encoding="utf-8") as file: + json.dump(report, file, indent=4) + + @staticmethod + def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: + """ + Build a dict of all parameters, {name:value}. + + Param values taken from command line if specified, defaults otherwise. 
+ """ + # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} + app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) + + # a map of params declared in the application, with values taken from the CLI, + # defaults otherwise + all_params = { + (p.alias or p.name): app_params_map.get( + cast(str, p.name), cast(str, p.default_value) + ) + for cmd in backend.commands.values() + for p in cmd.params + } + return cast(Dict[str, str], all_params) + + @staticmethod + def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: + """Extract static simulation information from the context.""" + if ctx.system is None: + raise ValueError("No system available to report.") + + info = { + "system": { + "name": ctx.system.name, + "params": Reporter._compute_all_params(ctx.system_params, ctx.system), + }, + "application": { + "name": ctx.app.name, + "params": Reporter._compute_all_params(ctx.app_params, ctx.app), + }, + } + return info + + +def validate_parameters( + backend: Backend, command_names: List[str], params: List[str] +) -> None: + """Check parameters passed to backend.""" + for param in params: + acceptable = any( + backend.validate_parameter(command_name, param) + for command_name in command_names + if command_name in backend.commands + ) + + if not acceptable: + backend_type = "System" if isinstance(backend, System) else "Application" + raise ValueError( + "{} parameter '{}' not valid for command '{}'".format( + backend_type, param, " or ".join(command_names) + ) + ) + + +def get_application_by_name_and_system( + application_name: str, system_name: str +) -> Application: + """Get application.""" + applications = get_application(application_name, system_name) + if not applications: + raise ValueError( + "Application '{}' doesn't support the system '{}'".format( + application_name, system_name + ) + ) + + if len(applications) != 1: + raise ValueError( + "Error during getting application {} for the system {}".format( + 
application_name, system_name + ) + ) + + return applications[0] + + +def get_application_and_system( + application_name: str, system_name: str +) -> Tuple[Application, System]: + """Return application and system by provided names.""" + system = get_system(system_name) + if not system: + raise ValueError("System {} is not found".format(system_name)) + + application = get_application_by_name_and_system(application_name, system_name) + + return application, system + + +# pylint: disable=too-many-arguments +def run_application( + application_name: str, + application_params: List[str], + system_name: str, + system_params: List[str], + custom_deploy_data: List[DataPaths], + report_file: Optional[Path] = None, +) -> ExecutionContext: + """Run application on the provided system.""" + application, system = get_application_and_system(application_name, system_name) + validate_parameters(application, ["build", "run"], application_params) + validate_parameters(system, ["build", "run"], system_params) + + execution_params = ExecutionParams() + if isinstance(system, StandaloneSystem): + execution_params["disable_locking"] = True + execution_params["unique_build_dir"] = True + + ctx = ExecutionContext( + app=application, + app_params=application_params, + system=system, + system_params=system_params, + custom_deploy_data=custom_deploy_data, + execution_params=execution_params, + report_file=report_file, + ) + + with build_dir_manager(ctx): + if ctx.is_build_required: + execute_application_command_build(ctx) + + execute_application_command_run(ctx) + + return ctx + + +def execute_application_command_build(ctx: ExecutionContext) -> None: + """Execute application command 'build'.""" + with ExitStack() as context_stack: + for manager in get_context_managers("build", ctx): + context_stack.enter_context(manager(ctx)) + + build_dir = ctx.build_dir() + recreate_directory(build_dir) + + build_commands = ctx.app.build_command( + "build", ctx.app_params, ctx.param_resolver + ) + 
        execute_commands_locally(build_commands, build_dir)


def execute_commands_locally(commands: List[str], cwd: Path) -> None:
    """Execute list of commands locally."""
    for command in commands:
        print("Running: {}".format(command))
        # terminate_on_error stops the whole sequence on the first failure.
        run_and_wait(
            command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr
        )


def execute_application_command_run(ctx: ExecutionContext) -> None:
    """Execute application command."""
    assert ctx.system is not None, "System must be provided."
    if ctx.is_deploy_needed and not ctx.system.supports_deploy:
        raise ConfigurationException(
            "System {} does not support data deploy".format(ctx.system.name)
        )

    with ExitStack() as context_stack:
        for manager in get_context_managers("run", ctx):
            context_stack.enter_context(manager(ctx))

        print("Generating commands to execute")
        commands_to_run = build_run_commands(ctx)

        if ctx.system.connectable:
            establish_connection(ctx)

        if ctx.system.supports_deploy:
            deploy_data(ctx)

        for command in commands_to_run:
            print("Running: {}".format(command))
            exit_code, ctx.stdout, ctx.stderr = ctx.system.run(command)

            # A non-zero exit code is reported but does not abort the run.
            if exit_code != 0:
                print("Application exited with exit code {}".format(exit_code))

            if ctx.reporter:
                ctx.reporter.parse(ctx.stdout)
                # Strip the parsed content so it is not shown twice.
                ctx.stdout = ctx.reporter.get_filtered_output(ctx.stdout)

        if ctx.reporter:
            report = ctx.reporter.report(ctx)
            ctx.reporter.save(report, cast(Path, ctx.report_file))


def establish_connection(
    ctx: ExecutionContext, retries: int = 90, interval: float = 15.0
) -> None:
    """Establish connection with the system.

    Polls ctx.system.establish_connection() up to `retries` times, sleeping
    `interval` seconds between attempts.

    :raises ConnectionException: when all retries are exhausted (for/else)
    :raises Exception: when a controlled system died while waiting
    """
    assert ctx.system is not None, "System is required."
    host, port = ctx.system.connection_details()
    print(
        "Trying to establish connection with '{}:{}' - "
        "{} retries every {} seconds ".format(host, port, retries, interval),
        end="",
    )

    try:
        for _ in range(retries):
            print(".", end="", flush=True)

            if ctx.system.establish_connection():
                break

            if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running():
                # The system process exited - dump its output for diagnosis.
                print(
                    "\n\n---------- {} execution failed ----------".format(
                        ctx.system.name
                    )
                )
                stdout, stderr = ctx.system.get_output()
                print(stdout)
                print(stderr)

                raise Exception("System is not running")

            wait(interval)
        else:
            raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port))
    finally:
        print()


def wait(interval: float) -> None:
    """Wait for a period of time."""
    time.sleep(interval)


def deploy_data(ctx: ExecutionContext) -> None:
    """Deploy data to the system."""
    assert ctx.system is not None, "System is required."
    for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data):
        print("Deploying {} onto {}".format(item.src, item.dst))
        ctx.system.deploy(item.src, item.dst)


def build_run_commands(ctx: ExecutionContext) -> List[str]:
    """Build commands to run application."""
    # For standalone systems the system itself provides the run command;
    # otherwise the application's run command is used.
    if isinstance(ctx.system, StandaloneSystem):
        return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver)

    return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver)


@contextmanager
def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Context manager used for system initialisation before run."""
    system = cast(ControlledSystem, ctx.system)
    commands = system.build_command("run", ctx.system_params, ctx.param_resolver)
    pid_file_path: Optional[Path] = None
    if ctx.is_locking_required:
        # Keep the pid file next to the lock file, sharing its base name.
        file_lock_path = get_file_lock_path(ctx)
        pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem)

    system.start(commands, ctx.is_locking_required, pid_file_path)
    try:
        yield
    finally:
        print("Shutting down sequence...")
        print("Stopping {}... (It could take few seconds)".format(system.name))
        system.stop(wait=True)
        print("{} stopped successfully.".format(system.name))


@contextmanager
def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Lock execution manager.

    :raises AnotherInstanceIsRunningException: if the lock is already held
    """
    file_lock_path = get_file_lock_path(ctx)
    file_lock = FileLock(str(file_lock_path))

    try:
        # Fail fast (1s) instead of queueing behind another run.
        file_lock.acquire(timeout=1)
    except Timeout as error:
        raise AnotherInstanceIsRunningException() from error

    try:
        yield
    finally:
        file_lock.release()


def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path:
    """Get file lock path.

    NOTE(review): the default lock_dir is hard-coded to /tmp; consider
    tempfile.gettempdir() for portability - confirm no callers rely on
    the literal path.
    """
    lock_modules = []
    if ctx.app.lock:
        lock_modules.append(ctx.app.name)
    if ctx.system is not None and ctx.system.lock:
        lock_modules.append(ctx.system.name)
    lock_filename = ""
    if lock_modules:
        lock_filename = "_".join(["middleware"] + lock_modules) + ".lock"

    if lock_filename:
        lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver)
        lock_filename = valid_for_filename(lock_filename)

    if not lock_filename:
        raise ConfigurationException("No filename for lock provided")

    if not isinstance(lock_dir, Path) or not lock_dir.is_dir():
        raise ConfigurationException(
            "Invalid directory {} for lock files provided".format(lock_dir)
        )

    return lock_dir / lock_filename


@contextmanager
def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Build directory manager.

    Cleans up a per-run (unique) build directory after execution.
    """
    try:
        yield
    finally:
        if (
            ctx.is_build_required
            and ctx.is_unique_build_dir_required
            and ctx.build_dir().is_dir()
        ):
            remove_directory(ctx.build_dir())


def get_context_managers(
    command_name: str, ctx: ExecutionContext
) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]:
    """Get context manager for the system."""
    managers = []

    if
 ctx.is_locking_required:
        managers.append(lock_execution_manager)

    if command_name == "run":
        if isinstance(ctx.system, ControlledSystem):
            managers.append(controlled_system_manager)

    return managers
diff --git a/src/mlia/backend/fs.py b/src/mlia/backend/fs.py
new file mode 100644
index 0000000..9979fcb
--- /dev/null
+++ b/src/mlia/backend/fs.py
@@ -0,0 +1,115 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module to host all file system related functions."""
import re
import shutil
from pathlib import Path
from typing import Any
from typing import Literal
from typing import Optional

from mlia.utils.filesystem import get_mlia_resources

# Only these two resource kinds exist under the backends folder.
ResourceType = Literal["applications", "systems"]


def get_backend_resources() -> Path:
    """Get backend resources folder path."""
    return get_mlia_resources() / "backends"


def get_backends_path(name: ResourceType) -> Path:
    """Return the absolute path of the specified resource.

    It uses importlib to return resources packaged with MANIFEST.in.

    NOTE(review): raising ResourceWarning (a Warning subclass) as an error
    is unusual; callers must catch ResourceWarning, not a regular Exception
    subclass.
    """
    if not name:
        raise ResourceWarning("Resource name is not provided")

    resource_path = get_backend_resources() / name
    if resource_path.is_dir():
        return resource_path

    raise ResourceWarning("Resource '{}' not found.".format(name))


def copy_directory_content(source: Path, destination: Path) -> None:
    """Copy content of the source directory into destination directory."""
    for item in source.iterdir():
        src = source / item.name
        dest = destination / item.name

        if src.is_dir():
            shutil.copytree(src, dest)
        else:
            shutil.copy2(src, dest)


def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
    """Remove resource data."""
    resources = get_backends_path(resource_type)

    resource_location = resources / resource_directory
    if not resource_location.exists():
        raise Exception("Resource {} does not exist".format(resource_directory))

    if not resource_location.is_dir():
        raise Exception("Wrong resource {}".format(resource_directory))

    shutil.rmtree(resource_location)


def remove_directory(directory_path: Optional[Path]) -> None:
    """Remove directory."""
    if not directory_path or not directory_path.is_dir():
        raise Exception("No directory path provided")

    shutil.rmtree(directory_path)


def recreate_directory(directory_path: Optional[Path]) -> None:
    """Recreate directory.

    Removes the directory if it exists, then creates it empty.
    """
    if not directory_path:
        raise Exception("No directory path provided")

    if directory_path.exists() and not directory_path.is_dir():
        raise Exception(
            "Path {} does exist and it is not a directory".format(str(directory_path))
        )

    if directory_path.is_dir():
        remove_directory(directory_path)

    directory_path.mkdir()


def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
    """Read file as string or bytearray.

    Returns "" (or b"" for mode "rb") when the file does not exist.
    """
    if file_path.is_file():
        if mode is not None:
            # Ignore pylint warning because mode can be 'binary' as well which
            # is not compatible with specifying encodings.
            with open(file_path, mode) as file:  # pylint: disable=unspecified-encoding
                return file.read()
        else:
            with open(file_path, encoding="utf-8") as file:
                return file.read()

    # Missing file: empty result matching the requested mode.
    if mode == "rb":
        return b""
    return ""


def read_file_as_string(file_path: Path) -> str:
    """Read file as string."""
    return str(read_file(file_path))


def read_file_as_bytearray(file_path: Path) -> bytearray:
    """Read a file as bytearray."""
    return bytearray(read_file(file_path, mode="rb"))


def valid_for_filename(value: str, replacement: str = "") -> str:
    """Replace non alpha numeric characters."""
    # ASCII flag keeps \w to [a-zA-Z0-9_]; dots are preserved.
    return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII)
diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py
new file mode 100644
index 0000000..3a1016c
--- /dev/null
+++ b/src/mlia/backend/manager.py
@@ -0,0 +1,447 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for backend integration."""
import logging
import re
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple

from mlia.backend.application import get_available_applications
from mlia.backend.application import install_application
from mlia.backend.common import DataPaths
from mlia.backend.execution import ExecutionContext
from mlia.backend.execution import run_application
from mlia.backend.system import get_available_systems
from mlia.backend.system import install_system
from mlia.utils.proc import OutputConsumer
from mlia.utils.proc import RunningCommand


logger = logging.getLogger(__name__)

# Mapping backend -> device_type -> system_name
_SUPPORTED_SYSTEMS = {
    "Corstone-300": {
        "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55",
        "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65",
    },
    "Corstone-310": {
        "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55",
    },
}

# Mapping system_name -> memory_mode -> application
_SYSTEM_TO_APP_MAP = {
    "Corstone-300: Cortex-M55+Ethos-U55": {
        "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
        "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
    },
    "Corstone-300: Cortex-M55+Ethos-U65": {
        "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
        "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
    },
    "Corstone-310: Cortex-M85+Ethos-U55": {
        "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
        "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
    },
}


def get_system_name(backend: str, device_type: str) -> str:
    """Get the system name for the given backend and device type.

    :raises KeyError: for unknown backend or device type
    """
    return _SUPPORTED_SYSTEMS[backend][device_type]


def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
    """Check if the backend (and optionally device type) is supported."""
    if device_type is None:
        return backend in _SUPPORTED_SYSTEMS

    try:
        get_system_name(backend, device_type)
        return True
    except KeyError:
        return False


def supported_backends() -> List[str]:
    """Get a list of all backends supported by the backend manager."""
    return list(_SUPPORTED_SYSTEMS.keys())


def get_all_system_names(backend: str) -> List[str]:
    """Get all systems supported by the backend."""
    return list(_SUPPORTED_SYSTEMS.get(backend, {}).values())


def get_all_application_names(backend: str) -> List[str]:
    """Get all applications supported by the backend.

    Returned order is unspecified (built from a set).
    """
    app_set = {
        app
        for sys in get_all_system_names(backend)
        for app in _SYSTEM_TO_APP_MAP[sys].values()
    }
    return list(app_set)


@dataclass
class DeviceInfo:
    """Device information."""

    device_type: Literal["ethos-u55", "ethos-u65"]
    mac: int
    memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"]
@dataclass
class ModelInfo:
    """Model info."""

    model_path: Path


@dataclass
class PerformanceMetrics:
    """Performance metrics parsed from generic inference output."""

    npu_active_cycles: int
    npu_idle_cycles: int
    npu_total_cycles: int
    npu_axi0_rd_data_beat_received: int
    npu_axi0_wr_data_beat_written: int
    npu_axi1_rd_data_beat_received: int


@dataclass
class ExecutionParams:
    """Application execution params."""

    application: str
    system: str
    application_params: List[str]
    system_params: List[str]
    deploy_params: List[str]


class LogWriter(OutputConsumer):
    """Redirect output to the logger."""

    def feed(self, line: str) -> None:
        """Process line from the output."""
        logger.debug(line.strip())


class GenericInferenceOutputParser(OutputConsumer):
    """Generic inference app output parser."""

    # Two alternative spellings per metric to support both old and new
    # inference runner output formats.
    # NOTE(review): the (?P<value>...) group names are reconstructed from the
    # match["value"] lookup below; the scraped patch lost the angle brackets.
    PATTERNS = {
        name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns)
        for name, patterns in (
            (
                "npu_active_cycles",
                (
                    r"NPU ACTIVE cycles: (?P<value>\d+)",
                    r"NPU ACTIVE: (?P<value>\d+) cycles",
                ),
            ),
            (
                "npu_idle_cycles",
                (
                    r"NPU IDLE cycles: (?P<value>\d+)",
                    r"NPU IDLE: (?P<value>\d+) cycles",
                ),
            ),
            (
                "npu_total_cycles",
                (
                    r"NPU TOTAL cycles: (?P<value>\d+)",
                    r"NPU TOTAL: (?P<value>\d+) cycles",
                ),
            ),
            (
                "npu_axi0_rd_data_beat_received",
                (
                    r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
                    r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
                ),
            ),
            (
                "npu_axi0_wr_data_beat_written",
                (
                    r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P<value>\d+)",
                    r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P<value>\d+) beats",
                ),
            ),
            (
                "npu_axi1_rd_data_beat_received",
                (
                    r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
                    r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
                ),
            ),
        )
    }

    def __init__(self) -> None:
        """Init generic inference output parser instance."""
        # Collected metrics, keyed like PATTERNS.
        self.result: Dict = {}

    def feed(self, line: str) -> None:
        """Feed new line to the parser."""
        # Early return: at most one metric is extracted per line.
        for name, patterns in self.PATTERNS.items():
            for pattern in patterns:
                match = pattern.search(line)

                if match:
                    self.result[name] = int(match["value"])
                    return

    def is_ready(self) -> bool:
        """Return true if all expected data has been parsed."""
        return self.result.keys() == self.PATTERNS.keys()

    def missed_keys(self) -> List[str]:
        """Return list of the keys that have not been found in the output."""
        return sorted(self.PATTERNS.keys() - self.result.keys())


class BackendRunner:
    """Backend runner."""

    def __init__(self) -> None:
        """Init BackendRunner instance."""

    @staticmethod
    def get_installed_systems() -> List[str]:
        """Get list of the installed systems."""
        return [system.name for system in get_available_systems()]

    @staticmethod
    def get_installed_applications(system: Optional[str] = None) -> List[str]:
        """Get list of the installed application."""
        return [
            app.name
            for app in get_available_applications()
            if system is None or app.can_run_on(system)
        ]

    def is_application_installed(self, application: str, system: str) -> bool:
        """Return true if requested application installed."""
        return application in self.get_installed_applications(system)

    def is_system_installed(self, system: str) -> bool:
        """Return true if requested system installed."""
        return system in self.get_installed_systems()

    def systems_installed(self, systems: List[str]) -> bool:
        """Check if all provided systems are installed."""
        # An empty request is treated as "nothing installed".
        if not systems:
            return False

        installed_systems = self.get_installed_systems()
        return all(system in installed_systems for system in systems)

    def applications_installed(self, applications: List[str]) -> bool:
        """Check if all provided applications are installed."""
        if not applications:
            return False

        installed_apps = self.get_installed_applications()
        return all(app in installed_apps for app in applications)

    def all_installed(self, systems: List[str], apps: List[str]) -> bool:
        """Check if all provided artifacts are installed."""
        return self.systems_installed(systems) and self.applications_installed(apps)

    @staticmethod
    def install_system(system_path: Path) -> None:
        """Install system."""
        install_system(system_path)

    @staticmethod
    def install_application(app_path: Path) -> None:
        """Install application."""
        install_application(app_path)

    @staticmethod
    def run_application(execution_params: ExecutionParams) -> ExecutionContext:
        """Run requested application."""

        def to_data_paths(paths: str) -> DataPaths:
            """Split input into two and create new DataPaths object."""
            # Format: "src:dst"; only the first ':' separates the pair.
            src, dst = paths.split(sep=":", maxsplit=1)
            return DataPaths(Path(src), dst)

        deploy_data_paths = [
            to_data_paths(paths) for paths in execution_params.deploy_params
        ]

        ctx = run_application(
            execution_params.application,
            execution_params.application_params,
            execution_params.system,
            execution_params.system_params,
            deploy_data_paths,
        )

        return ctx

    @staticmethod
    def _params(name: str, params: List[str]) -> List[str]:
        # Interleave the option name with each value:
        # ("-p", ["a", "b"]) -> ["-p", "a", "-p", "b"]
        return [p for item in [(name, param) for param in params] for p in item]


class GenericInferenceRunner(ABC):
    """Abstract class for generic inference runner."""

    def __init__(self, backend_runner: BackendRunner):
        """Init generic inference runner instance."""
        self.backend_runner = backend_runner
        self.running_inference: Optional[RunningCommand] = None

    def run(
        self, model_info: ModelInfo, output_consumers: List[OutputConsumer]
    ) -> None:
        """Run generic inference for the provided device/model."""
        execution_params = self.get_execution_params(model_info)

        ctx = self.backend_runner.run_application(execution_params)
        if ctx.stdout is not None:
            self.consume_output(ctx.stdout, output_consumers)

    def stop(self) -> None:
        """Stop running inference."""
        if self.running_inference is None:
            return

        self.running_inference.stop()

    @abstractmethod
    def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams:
        """Get
execution params for the provided model.""" + + def __enter__(self) -> "GenericInferenceRunner": + """Enter context.""" + return self + + def __exit__(self, *_args: Any) -> None: + """Exit context.""" + self.stop() + + def check_system_and_application(self, system_name: str, app_name: str) -> None: + """Check if requested system and application installed.""" + if not self.backend_runner.is_system_installed(system_name): + raise Exception(f"System {system_name} is not installed") + + if not self.backend_runner.is_application_installed(app_name, system_name): + raise Exception( + f"Application {app_name} for the system {system_name} " + "is not installed" + ) + + @staticmethod + def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> None: + """Pass program's output to the consumers.""" + for line in output.decode("utf8").splitlines(): + for consumer in consumers: + consumer.feed(line) + + +class GenericInferenceRunnerEthosU(GenericInferenceRunner): + """Generic inference runner on U55/65.""" + + def __init__( + self, backend_runner: BackendRunner, device_info: DeviceInfo, backend: str + ) -> None: + """Init generic inference runner instance.""" + super().__init__(backend_runner) + + system_name, app_name = self.resolve_system_and_app(device_info, backend) + self.system_name = system_name + self.app_name = app_name + self.device_info = device_info + + @staticmethod + def resolve_system_and_app( + device_info: DeviceInfo, backend: str + ) -> Tuple[str, str]: + """Find appropriate system and application for the provided device/backend.""" + try: + system_name = get_system_name(backend, device_info.device_type) + except KeyError as ex: + raise RuntimeError( + f"Unsupported device {device_info.device_type} " + f"for backend {backend}" + ) from ex + + if system_name not in _SYSTEM_TO_APP_MAP: + raise RuntimeError(f"System {system_name} is not installed") + + try: + app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode] + except KeyError as 
err: + raise RuntimeError( + f"Unsupported memory mode {device_info.memory_mode}" + ) from err + + return system_name, app_name + + def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: + """Get execution params for Ethos-U55/65.""" + self.check_system_and_application(self.system_name, self.app_name) + + system_params = [ + f"mac={self.device_info.mac}", + f"input_file={model_info.model_path.absolute()}", + ] + + return ExecutionParams( + self.app_name, + self.system_name, + [], + system_params, + [], + ) + + +def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner: + """Get generic runner for provided device and backend.""" + backend_runner = get_backend_runner() + return GenericInferenceRunnerEthosU(backend_runner, device_info, backend) + + +def estimate_performance( + model_info: ModelInfo, device_info: DeviceInfo, backend: str +) -> PerformanceMetrics: + """Get performance estimations.""" + with get_generic_runner(device_info, backend) as generic_runner: + output_parser = GenericInferenceOutputParser() + output_consumers = [output_parser, LogWriter()] + + generic_runner.run(model_info, output_consumers) + + if not output_parser.is_ready(): + missed_data = ",".join(output_parser.missed_keys()) + logger.debug( + "Unable to get performance metrics, missed data %s", missed_data + ) + raise Exception("Unable to get performance metrics, insufficient data") + + return PerformanceMetrics(**output_parser.result) + + +def get_backend_runner() -> BackendRunner: + """ + Return BackendRunner instance. + + Note: This is needed for the unit tests. + """ + return BackendRunner() diff --git a/src/mlia/backend/output_parser.py b/src/mlia/backend/output_parser.py new file mode 100644 index 0000000..111772a --- /dev/null +++ b/src/mlia/backend/output_parser.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0
"""Definition of output parsers (including base class OutputParser)."""
import base64
import json
import re
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import Union


class OutputParser(ABC):
    """Abstract base class for output parsers."""

    def __init__(self, name: str) -> None:
        """Set up the name of the parser."""
        super().__init__()
        self.name = name

    @abstractmethod
    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """Parse the output and return a map of names to metrics."""
        return {}

    # pylint: disable=no-self-use
    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """
        Filter out the parsed content from the output.

        Does nothing by default. Can be overridden in subclasses.
        """
        return output


class RegexOutputParser(OutputParser):
    """Parser of standard output data using regular expressions."""

    _TYPE_MAP = {"str": str, "float": float, "int": int}

    def __init__(
        self,
        name: str,
        regex_config: Dict[str, Dict[str, str]],
    ) -> None:
        """
        Set up the parser with the regular expressions.

        The regex_config is mapping from a name to a dict with keys 'pattern'
        and 'type':
        - The 'pattern' holds the regular expression that must contain exactly
          one capturing parenthesis
        - The 'type' can be one of ['str', 'float', 'int'].

        Example:
        ```
        {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}}
        ```

        The different regular expressions from the config are combined using
        non-capturing parenthesis, i.e. regular expressions must not overlap
        if more than one match per line is expected.
        """
        super().__init__(name)

        self._verify_config(regex_config)
        self._regex_cfg = regex_config

        # Combine all patterns into a single alternation of non-capturing
        # groups so one pass over the output matches every metric.
        combined = "|".join(f"(?:{cfg['pattern']})" for cfg in self._regex_cfg.values())
        self._regex = re.compile(combined)

    def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]:
        """
        Parse the output and return a map of names to metrics.

        Example:
            Assuming a regex_config as used as example in `__init__()` and the
            following output:
            ```
            Simulation finished:
            SIMULATION_STATUS = SUCCESS
            Simulation DONE
            ```
            Then calling the parser should return the following dict:
            ```
            {
                "Metric1": "SUCCESS"
            }
            ```
        """
        metrics: Dict[str, Union[str, float, int]] = {}
        for match_result in self._regex.findall(output.decode("utf-8")):
            # findall() yields a plain string when the combined regex has one
            # group and a tuple of per-group strings otherwise.
            is_single_group = isinstance(match_result, str)
            for idx, (metric_name, cfg) in enumerate(self._regex_cfg.items()):
                value = match_result if is_single_group else match_result[idx]
                if value:
                    metrics[metric_name] = self._TYPE_MAP[cfg["type"]](value)
        return metrics

    def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None:
        """Make sure we have a valid regex_config.

        I.e.
        - Exactly one capturing parenthesis per pattern
        - Correct types
        """
        for name, cfg in regex_config.items():
            # Each pattern must define exactly one capturing group.
            regex = re.compile(cfg["pattern"])
            if regex.groups != 1:
                raise ValueError(
                    f"Pattern for metric '{name}' must have exactly one "
                    f"capturing parenthesis, but it has {regex.groups}."
                )
            # The declared type must be one of the supported conversions.
            if cfg["type"] not in self._TYPE_MAP:
                raise TypeError(
                    f"Type '{cfg['type']}' for metric '{name}' is not "
                    f"supported. Choose from: {list(self._TYPE_MAP.keys())}."
                )


class Base64OutputParser(OutputParser):
    """
    Parser to extract base64-encoded JSON from tagged standard output.

    Example of the tagged output:
    ```
        # Encoded JSON: {"test": 1234}
        eyJ0ZXN0IjogMTIzNH0
    ```
    """

    TAG_NAME = "metrics"

    def __init__(self, name: str) -> None:
        """Set up the regular expression to extract tagged strings."""
        super().__init__(name)
        self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)")

    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """
        Parse the output and return a map of index (as string) to decoded JSON.

        Example:
            Using the tagged output from the class docs the parser should return
            the following dict:
            ```
            {
                "0": {"test": 1234}
            }
            ```
        """
        tagged_sections = self._regex.findall(output.decode("utf-8"))
        return {
            str(idx): json.loads(base64.b64decode(section, validate=True))
            for idx, section in enumerate(tagged_sections)
        }

    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """Filter out base64-encoded content from the output."""
        remaining = self._regex.sub("", output.decode("utf-8"))
        return bytearray(remaining.encode("utf-8"))
+""" +import csv +import datetime +import logging +import shlex +import signal +import time +from pathlib import Path +from typing import Any +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple + +import psutil +from sh import Command +from sh import CommandNotFound +from sh import ErrorReturnCode +from sh import RunningCommand + +from mlia.backend.fs import valid_for_filename + + +class CommandFailedException(Exception): + """Exception for failed command execution.""" + + +class ShellCommand: + """Wrapper class for shell commands.""" + + def __init__(self, base_log_path: str = "/tmp") -> None: + """Initialise the class. + + base_log_path: it is the base directory where logs will be stored + """ + self.base_log_path = base_log_path + + def run( + self, + cmd: str, + *args: str, + _cwd: Optional[Path] = None, + _tee: bool = True, + _bg: bool = True, + _out: Any = None, + _err: Any = None, + _search_paths: Optional[List[Path]] = None + ) -> RunningCommand: + """Run the shell command with the given arguments. + + There are special arguments that modify the behaviour of the process. + _cwd: current working directory + _tee: it redirects the stdout both to console and file + _bg: if True, it runs the process in background and the command is not + blocking. 
+ _out: use this object for stdout redirect, + _err: use this object for stderr redirect, + _search_paths: If presented used for searching executable + """ + try: + kwargs = {} + if _cwd: + kwargs["_cwd"] = str(_cwd) + command = Command(cmd, _search_paths).bake(args, **kwargs) + except CommandNotFound as error: + logging.error("Command '%s' not found", error.args[0]) + raise error + + out, err = _out, _err + if not _out and not _err: + out, err = [ + str(item) + for item in self.get_stdout_stderr_paths(self.base_log_path, cmd) + ] + + return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) + + @classmethod + def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + timestamp = datetime.datetime.now().timestamp() + base_path = Path(base_log_path) + base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) + stdout = base.with_suffix(".out") + stderr = base.with_suffix(".err") + try: + stdout.touch() + stderr.touch() + except FileNotFoundError as error: + logging.error("File not found: %s", error.filename) + raise error + return stdout, stderr + + +def parse_command(command: str, shell: str = "bash") -> List[str]: + """Parse command.""" + cmd, *args = shlex.split(command, posix=True) + + if is_shell_script(cmd): + args = [cmd] + args + cmd = shell + + return [cmd] + args + + +def get_stdout_stderr_paths( + command: str, base_log_path: str = "/tmp" +) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + cmd, *_ = parse_command(command) + + return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd) + + +def execute_command( # pylint: disable=invalid-name + command: str, + cwd: Path, + bg: bool = False, + shell: str = "bash", + out: Any = None, + err: Any = None, +) -> RunningCommand: + """Execute shell command.""" + cmd, *args = parse_command(command, shell) + + search_paths = None + if cmd != shell and (cwd 
/ cmd).is_file(): + search_paths = [cwd] + + return ShellCommand().run( + cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err + ) + + +def is_shell_script(cmd: str) -> bool: + """Check if command is shell script.""" + return cmd.endswith(".sh") + + +def run_and_wait( + command: str, + cwd: Path, + terminate_on_error: bool = False, + out: Any = None, + err: Any = None, +) -> Tuple[int, bytearray, bytearray]: + """ + Run command and wait while it is executing. + + Returns a tuple: (exit_code, stdout, stderr) + """ + running_cmd: Optional[RunningCommand] = None + try: + running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) + return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr + except ErrorReturnCode as cmd_failed: + raise CommandFailedException() from cmd_failed + except Exception as error: + is_running = running_cmd is not None and running_cmd.is_alive() + if terminate_on_error and is_running: + print("Terminating ...") + terminate_command(running_cmd) + + raise error + + +def terminate_command( + running_cmd: RunningCommand, + wait: bool = True, + wait_period: float = 0.5, + number_of_attempts: int = 20, +) -> None: + """Terminate running command.""" + try: + running_cmd.process.signal_group(signal.SIGINT) + if wait: + for _ in range(number_of_attempts): + time.sleep(wait_period) + if not running_cmd.is_alive(): + return + print( + "Unable to terminate process {}. 
Sending SIGTERM...".format( + running_cmd.process.pid + ) + ) + running_cmd.process.signal_group(signal.SIGTERM) + except ProcessLookupError: + pass + + +def terminate_external_process( + process: psutil.Process, + wait_period: float = 0.5, + number_of_attempts: int = 20, + wait_for_termination: float = 5.0, +) -> None: + """Terminate external process.""" + try: + process.terminate() + for _ in range(number_of_attempts): + if not process.is_running(): + return + time.sleep(wait_period) + + if process.is_running(): + process.terminate() + time.sleep(wait_for_termination) + except psutil.Error: + print("Unable to terminate process") + + +class ProcessInfo(NamedTuple): + """Process information.""" + + name: str + executable: str + cwd: str + pid: int + + +def save_process_info(pid: int, pid_file: Path) -> None: + """Save process information to file.""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + family = [parent] + children + + with open(pid_file, "w", encoding="utf-8") as file: + csv_writer = csv.writer(file) + for member in family: + process_info = ProcessInfo( + name=member.name(), + executable=member.exe(), + cwd=member.cwd(), + pid=member.pid, + ) + csv_writer.writerow(process_info) + except psutil.NoSuchProcess: + # if process does not exist or finishes before + # function call then nothing could be saved + # just ignore this exception and exit + pass + + +def read_process_info(pid_file: Path) -> List[ProcessInfo]: + """Read information about previous system processes.""" + if not pid_file.is_file(): + return [] + + result = [] + with open(pid_file, encoding="utf-8") as file: + csv_reader = csv.reader(file) + for row in csv_reader: + name, executable, cwd, pid = row + result.append( + ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) + ) + + return result + + +def print_command_stdout(command: RunningCommand) -> None: + """Print the stdout of a command. + + The command has 2 states: running and done. 
+ If the command is running, the output is taken by the running process. + If the command has ended its execution, the stdout is taken from stdout + property + """ + if command.is_alive(): + while True: + try: + print(command.next(), end="") + except StopIteration: + break + else: + print(command.stdout) diff --git a/src/mlia/backend/protocol.py b/src/mlia/backend/protocol.py new file mode 100644 index 0000000..ebfe69a --- /dev/null +++ b/src/mlia/backend/protocol.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain protocol related classes and functions.""" +from abc import ABC +from abc import abstractmethod +from contextlib import closing +from pathlib import Path +from typing import Any +from typing import cast +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Union + +import paramiko + +from mlia.backend.common import ConfigurationException +from mlia.backend.config import LocalProtocolConfig +from mlia.backend.config import SSHConfig +from mlia.backend.proc import run_and_wait + + +# Redirect all paramiko thread exceptions to a file otherwise these will be +# printed to stderr. 
paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)


class SSHConnectionException(Exception):
    """SSH connection exception."""


class SupportsClose(ABC):
    """Class indicates support of close operation."""

    @abstractmethod
    def close(self) -> None:
        """Close protocol session."""


class SupportsDeploy(ABC):
    """Class indicates support of deploy operation."""

    @abstractmethod
    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Abstract method for deploy data."""


class SupportsConnection(ABC):
    """Class indicates that protocol uses network connections."""

    @abstractmethod
    def establish_connection(self) -> bool:
        """Establish connection with underlying system."""

    @abstractmethod
    def connection_details(self) -> Tuple[str, int]:
        """Return connection details (host, port)."""


class Protocol(ABC):
    """Abstract class for representing the protocol."""

    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
        """Initialize the class using a dict.

        The given config entries become instance attributes (e.g.
        self.hostname), then subclass validation is run over them.
        """
        self.__dict__.update(iterable, **kwargs)
        self._validate()

    @abstractmethod
    def _validate(self) -> None:
        """Abstract method for config data validation."""

    @abstractmethod
    def run(
        self, command: str, retry: bool = False
    ) -> Tuple[int, bytearray, bytearray]:
        """
        Abstract method for running commands.

        Returns a tuple: (exit_code, stdout, stderr)
        """


class CustomSFTPClient(paramiko.SFTPClient):
    """Class for creating a custom sftp client."""

    def put_dir(self, source: Path, target: str) -> None:
        """Upload the source directory to the target path.

        The target directory needs to exists and the last directory of the
        source will be created under the target with all its content.
        """
        # Create the target directory
        self._mkdir(target, ignore_existing=True)
        # Create the last directory in the source on the target
        self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
        # Go through the whole content of source
        for item in sorted(source.glob("**/*")):
            relative_path = item.relative_to(source.parent)
            remote_target = target / relative_path
            if item.is_file():
                self.put(str(item), str(remote_target))
            else:
                self._mkdir(str(remote_target), ignore_existing=True)

    # NOTE: default mode 511 == 0o777
    def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
        """Extend mkdir functionality.

        This version adds an option to not fail if the folder exists.
        """
        try:
            super().mkdir(path, mode)
        except IOError as error:
            if ignore_existing:
                pass
            else:
                raise error


class LocalProtocol(Protocol):
    """Class for local protocol."""

    protocol: str
    cwd: Path

    def run(
        self, command: str, retry: bool = False
    ) -> Tuple[int, bytearray, bytearray]:
        """
        Run command locally.

        Returns a tuple: (exit_code, stdout, stderr)

        NOTE: the 'retry' parameter is accepted for interface compatibility
        but is not used by the local implementation.
        """
        if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
            raise ConfigurationException("Wrong working directory {}".format(self.cwd))

        # run_and_wait() fills these buffers via its out/err redirection.
        stdout = bytearray()
        stderr = bytearray()

        return run_and_wait(
            command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
        )

    def _validate(self) -> None:
        """Validate protocol configuration."""
        assert hasattr(self, "protocol") and self.protocol == "local"
        assert hasattr(self, "cwd")


class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
    """Class for SSH protocol."""

    protocol: str
    username: str
    password: str
    hostname: str
    port: int

    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
        """Initialize the class using a dict."""
        super().__init__(iterable, **kwargs)
        # Internal state to store if the system is connectable. It will be set
        # to true at the first connection instance
        self.client: Optional[paramiko.client.SSHClient] = None
        self.port = int(self.port)

    def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
        """
        Run command over SSH.

        Returns a tuple: (exit_code, stdout, stderr)

        NOTE: 'retry' is accepted for interface compatibility but is not used
        in this method.
        """
        transport = self._get_transport()
        with closing(transport.open_session()) as channel:
            # Enable shell's .profile settings and execute command
            channel.exec_command("bash -l -c '{}'".format(command))
            exit_status = -1
            stdout = bytearray()
            stderr = bytearray()
            # Busy-poll the channel, draining stdout/stderr while the remote
            # command runs and once more after it exits.
            while True:
                if channel.exit_status_ready():
                    exit_status = channel.recv_exit_status()
                    # Call it one last time to read any leftover in the channel
                    self._recv_stdout_err(channel, stdout, stderr)
                    break
                self._recv_stdout_err(channel, stdout, stderr)

        return exit_status, stdout, stderr

    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Deploy src to remote dst over SSH.

        src and dst should be path to a file or directory.
        """
        transport = self._get_transport()
        sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))

        with closing(sftp):
            if src.is_dir():
                sftp.put_dir(src, dst)
            elif src.is_file():
                sftp.put(str(src), dst)
            else:
                raise Exception("Deploy error: file type not supported")

        # After the deployment of files, sync the remote filesystem to flush
        # buffers to hard disk
        self.run("sync")

    def close(self) -> None:
        """Close protocol session."""
        if self.client is not None:
            print("Try syncing remote file system...")
            # Before stopping the system, we try to run sync to make sure all
            # data are flushed on disk.
            self.run("sync", retry=False)
            self._close_client(self.client)

    def establish_connection(self) -> bool:
        """Establish connection with underlying system."""
        if self.client is not None:
            return True

        self.client = self._connect()
        return self.client is not None

    def _get_transport(self) -> paramiko.transport.Transport:
        """Get transport, connecting first if necessary."""
        self.establish_connection()

        if self.client is None:
            raise SSHConnectionException(
                "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
            )

        transport = self.client.get_transport()
        if not transport:
            raise Exception("Unable to get transport")

        return transport

    def connection_details(self) -> Tuple[str, int]:
        """Return connection details of underlying system."""
        return (self.hostname, self.port)

    def _connect(self) -> Optional[paramiko.client.SSHClient]:
        """Try to establish connection; return None on failure."""
        client: Optional[paramiko.client.SSHClient] = None
        try:
            client = paramiko.client.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(
                self.hostname,
                self.port,
                self.username,
                self.password,
                # next parameters should be set to False to disable authentication
                # using ssh keys
                allow_agent=False,
                look_for_keys=False,
            )
            return client
        except (
            # OSError raised on first attempt to connect when running inside Docker
            OSError,
            paramiko.ssh_exception.NoValidConnectionsError,
            paramiko.ssh_exception.SSHException,
        ):
            # even if connection is not established socket could be still
            # open, it should be closed
            self._close_client(client)

        return None

    @staticmethod
    def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
        """Close ssh client, ignoring any error raised while closing."""
        try:
            if client is not None:
                client.close()
        except Exception:  # pylint: disable=broad-except
            pass

    @classmethod
    def _recv_stdout_err(
        cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
    ) -> None:
        """Read from channel to stdout/stder."""
        # Read at most one chunk per stream per call; run() keeps polling.
        chunk_size = 512
        if channel.recv_ready():
            stdout_chunk = channel.recv(chunk_size)
            stdout.extend(stdout_chunk)
        if channel.recv_stderr_ready():
            stderr_chunk = channel.recv_stderr(chunk_size)
            stderr.extend(stderr_chunk)

    def _validate(self) -> None:
        """Check if there are all the info for establishing the connection."""
        assert hasattr(self, "protocol") and self.protocol == "ssh"
        assert hasattr(self, "username")
        assert hasattr(self, "password")
        assert hasattr(self, "hostname")
        assert hasattr(self, "port")


class ProtocolFactory:
    """Factory class to return the appropriate Protocol class."""

    @staticmethod
    def get_protocol(
        config: Optional[Union[SSHConfig, LocalProtocolConfig]],
        **kwargs: Union[str, Path, None]
    ) -> Union[SSHProtocol, LocalProtocol]:
        """Return the right protocol instance based on the config."""
        if not config:
            raise ValueError("No protocol config provided")

        protocol = config["protocol"]
        if protocol == "ssh":
            return SSHProtocol(config)

        if protocol == "local":
            cwd = kwargs.get("cwd")
            return LocalProtocol(config, cwd=cwd)

        raise ValueError("Protocol not supported: '{}'".format(protocol))
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain source related classes and functions."""
import os
import shutil
import tarfile
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from tarfile import TarFile
from typing import Optional
from typing import Union

from mlia.backend.common import BACKEND_CONFIG_FILE
from mlia.backend.common import ConfigurationException
from mlia.backend.common import get_backend_config
from mlia.backend.common import is_backend_directory
from mlia.backend.common import load_config
from mlia.backend.config import BackendConfig
from mlia.backend.fs import copy_directory_content


class Source(ABC):
    """Abstract base class for an installable backend source."""

    @abstractmethod
    def name(self) -> Optional[str]:
        """Get source name."""

    @abstractmethod
    def config(self) -> Optional[BackendConfig]:
        """Get configuration file content."""

    @abstractmethod
    def install_into(self, destination: Path) -> None:
        """Install source into destination directory."""

    @abstractmethod
    def create_destination(self) -> bool:
        """Return True if destination folder should be created before installation."""


class DirectorySource(Source):
    """Source backed by a plain directory on disk."""

    def __init__(self, directory_path: Path) -> None:
        """Create the DirectorySource instance."""
        assert isinstance(directory_path, Path)
        self.directory_path = directory_path

    def name(self) -> str:
        """Return name of source (the directory name)."""
        return self.directory_path.name

    def config(self) -> Optional[BackendConfig]:
        """Return configuration file content.

        Raises ConfigurationException if the directory does not contain a
        backend configuration file.
        """
        if not is_backend_directory(self.directory_path):
            raise ConfigurationException("No configuration file found")

        config_file = get_backend_config(self.directory_path)
        return load_config(config_file)

    def install_into(self, destination: Path) -> None:
        """Copy the directory content into the destination directory.

        Raises ConfigurationException if either the destination or the
        source directory does not exist.
        """
        if not destination.is_dir():
            raise ConfigurationException("Wrong destination {}".format(destination))

        if not self.directory_path.is_dir():
            raise ConfigurationException(
                "Directory {} does not exist".format(self.directory_path)
            )

        copy_directory_content(self.directory_path, destination)

    def create_destination(self) -> bool:
        """Return True if destination folder should be created before installation."""
        return True


class TarArchiveSource(Source):
    """Source backed by a tar.gz/tgz archive."""

    def __init__(self, archive_path: Path) -> None:
        """Create the TarArchiveSource class."""
        assert isinstance(archive_path, Path)
        self.archive_path = archive_path
        # Lazily populated by _read_archive_content()
        self._config: Optional[BackendConfig] = None
        self._has_top_level_folder: Optional[bool] = None
        self._name: Optional[str] = None

    def _read_archive_content(self) -> None:
        """Read name, layout and configuration from the archive."""
        # Get the source name from the archive name (file name without its
        # extensions). BUG FIX: the previous implementation used
        # name.rstrip(extensions), but str.rstrip() strips a *set of
        # characters*, not a suffix: "target.tar.gz".rstrip(".tar.gz")
        # yields "targe". Remove the suffix by slicing instead.
        extensions = "".join(self.archive_path.suffixes)
        self._name = (
            self.archive_path.name[: -len(extensions)]
            if extensions
            else self.archive_path.name
        )

        if not self.archive_path.exists():
            return

        with self._open(self.archive_path) as archive:
            try:
                # Configuration file at the archive root: the content is
                # installed into a newly created destination directory.
                config_entry = archive.getmember(BACKEND_CONFIG_FILE)
                self._has_top_level_folder = False
            except KeyError as error_no_config:
                try:
                    # Otherwise the archive must have a single top level
                    # directory containing the configuration file; that
                    # directory also provides the source name.
                    archive_entries = archive.getnames()
                    entries_common_prefix = os.path.commonprefix(archive_entries)
                    top_level_dir = entries_common_prefix.rstrip("/")

                    if not top_level_dir:
                        raise RuntimeError(
                            "Archive has no top level directory"
                        ) from error_no_config

                    config_path = "{}/{}".format(top_level_dir, BACKEND_CONFIG_FILE)

                    config_entry = archive.getmember(config_path)
                    self._has_top_level_folder = True
                    self._name = top_level_dir
                except (KeyError, RuntimeError) as error_no_root_dir_or_config:
                    raise ConfigurationException(
                        "No configuration file found"
                    ) from error_no_root_dir_or_config

            content = archive.extractfile(config_entry)
            self._config = load_config(content)

    def config(self) -> Optional[BackendConfig]:
        """Return configuration file content (read lazily from the archive)."""
        if self._config is None:
            self._read_archive_content()

        return self._config

    def name(self) -> Optional[str]:
        """Return name of the source (read lazily from the archive)."""
        if self._name is None:
            self._read_archive_content()

        return self._name

    def create_destination(self) -> bool:
        """Return True if destination folder must be created before installation."""
        if self._has_top_level_folder is None:
            self._read_archive_content()

        return not self._has_top_level_folder

    def install_into(self, destination: Path) -> None:
        """Extract the archive into the destination directory.

        Raises ConfigurationException if the destination does not exist.
        """
        if not destination.is_dir():
            raise ConfigurationException("Wrong destination {}".format(destination))

        with self._open(self.archive_path) as archive:
            # NOTE(review): extractall() trusts member paths from the
            # archive; a malicious archive could escape 'destination'
            # (tar path traversal). Consider validating member names.
            archive.extractall(destination)

    def _open(self, archive_path: Path) -> TarFile:
        """Open archive file.

        Only gzip-compressed tarballs (.tar.gz/.tgz) are supported.
        """
        if not archive_path.is_file():
            raise ConfigurationException("File {} does not exist".format(archive_path))

        if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"):
            mode = "r:gz"
        else:
            raise ConfigurationException(
                "Unsupported archive type {}".format(archive_path)
            )

        # The returned TarFile object can be used as a context manager (using
        # 'with') by the calling instance.
        return tarfile.open(  # pylint: disable=consider-using-with
            self.archive_path, mode=mode
        )


def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
    """Return appropriate source instance based on provided source path."""
    if source_path.is_file():
        return TarArchiveSource(source_path)

    if source_path.is_dir():
        return DirectorySource(source_path)

    raise ConfigurationException("Unable to read {}".format(source_path))


def create_destination_and_install(source: Source, resource_path: Path) -> None:
    """Create destination directory and install source.

    This function is used for actual installation of system/backend New
    directory will be created inside :resource_path: if needed If for example
    archive contains top level folder then no need to create new directory
    """
    destination = resource_path
    create_destination = source.create_destination()

    if create_destination:
        name = source.name()
        if not name:
            raise ConfigurationException("Unable to get source name")

        destination = resource_path / name
        destination.mkdir()
    try:
        source.install_into(destination)
    except Exception as error:
        # Roll back the directory we created so a failed install does not
        # leave a partial backend behind.
        if create_destination:
            shutil.rmtree(destination)
        raise error
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""System backend module."""
from pathlib import Path
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from mlia.backend.common import Backend
from mlia.backend.common import ConfigurationException
from mlia.backend.common import get_backend_configs
from mlia.backend.common import get_backend_directories
from mlia.backend.common import load_config
from mlia.backend.common import remove_backend
from mlia.backend.config import SystemConfig
from mlia.backend.controller import SystemController
from mlia.backend.controller import SystemControllerSingleInstance
from mlia.backend.fs import get_backends_path
from mlia.backend.protocol import ProtocolFactory
from mlia.backend.protocol import SupportsClose
from mlia.backend.protocol import SupportsConnection
from mlia.backend.protocol import SupportsDeploy
from mlia.backend.source import create_destination_and_install
from mlia.backend.source import get_source
+def get_available_systems_directory_names() -> List[str]: + """Return a list of directory names for all available systems.""" + return [entry.name for entry in get_backend_directories("systems")] + + +def get_available_systems() -> List["System"]: + """Return a list with all available systems.""" + available_systems = [] + for config_json in get_backend_configs("systems"): + config_entries = cast(List[SystemConfig], (load_config(config_json))) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + system = load_system(config_entry) + available_systems.append(system) + + return sorted(available_systems, key=lambda system: system.name) + + +def get_system(system_name: str) -> Optional["System"]: + """Return a system instance with the same name passed as argument.""" + available_systems = get_available_systems() + for system in available_systems: + if system_name == system.name: + return system + return None + + +def install_system(source_path: Path) -> None: + """Install new system.""" + try: + source = get_source(source_path) + config = cast(List[SystemConfig], source.config()) + systems_to_install = [load_system(entry) for entry in config] + except Exception as error: + raise ConfigurationException("Unable to read system definition") from error + + if not systems_to_install: + raise ConfigurationException("No system definition found") + + available_systems = get_available_systems() + already_installed = [s for s in systems_to_install if s in available_systems] + if already_installed: + names = [system.name for system in already_installed] + raise ConfigurationException( + "Systems [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_backends_path("systems")) + + +def remove_system(directory_name: str) -> None: + """Remove system.""" + remove_backend(directory_name, "systems") + + +class System(Backend): + """System class.""" + + def __init__(self, config:
SystemConfig) -> None: + """Construct the System class using the dictionary passed.""" + super().__init__(config) + + self._setup_data_transfer(config) + self._setup_reporting(config) + + def _setup_data_transfer(self, config: SystemConfig) -> None: + data_transfer_config = config.get("data_transfer") + protocol = ProtocolFactory().get_protocol( + data_transfer_config, cwd=self.config_location + ) + self.protocol = protocol + + def _setup_reporting(self, config: SystemConfig) -> None: + self.reporting = config.get("reporting") + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command on the system. + + Returns a tuple: (exit_code, stdout, stderr) + """ + return self.protocol.run(command, retry) + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy files to the system.""" + if isinstance(self.protocol, SupportsDeploy): + self.protocol.deploy(src, dst, retry) + + @property + def supports_deploy(self) -> bool: + """Check if protocol supports deploy operation.""" + return isinstance(self.protocol, SupportsDeploy) + + @property + def connectable(self) -> bool: + """Check if protocol supports connection.""" + return isinstance(self.protocol, SupportsConnection) + + def establish_connection(self) -> bool: + """Establish connection with the system.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.establish_connection() + + def connection_details(self) -> Tuple[str, int]: + """Return connection details.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.connection_details() + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, System): + return False + + return super().__eq__(other) and 
self.name == other.name + + def get_details(self) -> Dict[str, Any]: + """Return a dictionary with all relevant information of a System.""" + output = { + "type": "system", + "name": self.name, + "description": self.description, + "data_transfer_protocol": self.protocol.protocol, + "commands": self._get_command_details(), + "annotations": self.annotations, + } + + return output + + +class StandaloneSystem(System): + """StandaloneSystem class.""" + + +def get_controller( + single_instance: bool, pid_file_path: Optional[Path] = None +) -> SystemController: + """Get system controller.""" + if single_instance: + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +class ControlledSystem(System): + """ControlledSystem class.""" + + def __init__(self, config: SystemConfig): + """Construct the ControlledSystem class using the dictionary passed.""" + super().__init__(config) + self.controller: Optional[SystemController] = None + + def start( + self, + commands: List[str], + single_instance: bool = True, + pid_file_path: Optional[Path] = None, + ) -> None: + """Launch the system.""" + if ( + not isinstance(self.config_location, Path) + or not self.config_location.is_dir() + ): + raise ConfigurationException( + "System {} has wrong config location".format(self.name) + ) + + self.controller = get_controller(single_instance, pid_file_path) + self.controller.start(commands, self.config_location) + + def is_running(self) -> bool: + """Check if system is running.""" + if not self.controller: + return False + + return self.controller.is_running() + + def get_output(self) -> Tuple[str, str]: + """Return system output.""" + if not self.controller: + return "", "" + + return self.controller.get_output() + + def stop(self, wait: bool = False) -> None: + """Stop the system.""" + if not self.controller: + raise Exception("System has not been started") + + if isinstance(self.protocol, SupportsClose): + try: + self.protocol.close() + except Exception as 
error: # pylint: disable=broad-except + print(error) + self.controller.stop(wait) + + +def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: + """Load system based on it's execution type.""" + data_transfer = config.get("data_transfer", {}) + protocol = data_transfer.get("protocol") + populate_shared_params(config) + + if protocol == "ssh": + return ControlledSystem(config) + + if protocol == "local": + return StandaloneSystem(config) + + raise ConfigurationException( + "Unsupported execution type for protocol {}".format(protocol) + ) + + +def populate_shared_params(config: SystemConfig) -> None: + """Populate command parameters with shared parameters.""" + user_params = config.get("user_params") + if not user_params or "shared" not in user_params: + return + + shared_user_params = user_params["shared"] + if not shared_user_params: + return + + only_aliases = all(p.get("alias") for p in shared_user_params) + if not only_aliases: + raise ConfigurationException("All shared parameters should have aliases") + + commands = config.get("commands", {}) + for cmd_name in ["build", "run"]: + command = commands.get(cmd_name) + if command is None: + commands[cmd_name] = [] + cmd_user_params = user_params.get(cmd_name) + if not cmd_user_params: + cmd_user_params = shared_user_params + else: + only_aliases = all(p.get("alias") for p in cmd_user_params) + if not only_aliases: + raise ConfigurationException( + "All parameters for command {} should have aliases".format(cmd_name) + ) + merged_by_alias = { + **{p.get("alias"): p for p in shared_user_params}, + **{p.get("alias"): p for p in cmd_user_params}, + } + cmd_user_params = list(merged_by_alias.values()) + + user_params[cmd_name] = cmd_user_params + + config["commands"] = commands + del user_params["shared"] diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py index 838b051..a673230 100644 --- a/src/mlia/cli/config.py +++ b/src/mlia/cli/config.py @@ -5,7 +5,7 @@ import logging from 
functools import lru_cache from typing import List -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager from mlia.tools.metadata.common import DefaultInstallationManager from mlia.tools.metadata.common import InstallationManager from mlia.tools.metadata.corstone import get_corstone_installations @@ -25,12 +25,12 @@ def get_available_backends() -> List[str]: """Return list of the available backends.""" available_backends = ["Vela"] - # Add backends using AIET + # Add backends using backend manager manager = get_installation_manager() available_backends.extend( ( backend - for backend in aiet.supported_backends() + for backend in backend_manager.supported_backends() if manager.backend_installed(backend) ) ) diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py index b0718a5..a73045a 100644 --- a/src/mlia/devices/ethosu/performance.py +++ b/src/mlia/devices/ethosu/performance.py @@ -10,7 +10,7 @@ from typing import Optional from typing import Tuple from typing import Union -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager import mlia.tools.vela_wrapper as vela from mlia.core.context import Context from mlia.core.performance import PerformanceEstimator @@ -147,15 +147,15 @@ class VelaPerformanceEstimator( return memory_usage -class AIETPerformanceEstimator( +class CorstonePerformanceEstimator( PerformanceEstimator[Union[Path, ModelConfiguration], NPUCycles] ): - """AIET based performance estimator.""" + """Corstone-based performance estimator.""" def __init__( self, context: Context, device: EthosUConfiguration, backend: str ) -> None: - """Init AIET based performance estimator.""" + """Init Corstone-based performance estimator.""" self.context = context self.device = device self.backend = backend @@ -179,24 +179,24 @@ class AIETPerformanceEstimator( model_path, self.device.compiler_options, optimized_model_path ) - model_info = 
aiet.ModelInfo(model_path=optimized_model_path) - device_info = aiet.DeviceInfo( + model_info = backend_manager.ModelInfo(model_path=optimized_model_path) + device_info = backend_manager.DeviceInfo( device_type=self.device.target, # type: ignore mac=self.device.mac, memory_mode=self.device.compiler_options.memory_mode, # type: ignore ) - aiet_perf_metrics = aiet.estimate_performance( + corstone_perf_metrics = backend_manager.estimate_performance( model_info, device_info, self.backend ) npu_cycles = NPUCycles( - aiet_perf_metrics.npu_active_cycles, - aiet_perf_metrics.npu_idle_cycles, - aiet_perf_metrics.npu_total_cycles, - aiet_perf_metrics.npu_axi0_rd_data_beat_received, - aiet_perf_metrics.npu_axi0_wr_data_beat_written, - aiet_perf_metrics.npu_axi1_rd_data_beat_received, + corstone_perf_metrics.npu_active_cycles, + corstone_perf_metrics.npu_idle_cycles, + corstone_perf_metrics.npu_total_cycles, + corstone_perf_metrics.npu_axi0_rd_data_beat_received, + corstone_perf_metrics.npu_axi0_wr_data_beat_written, + corstone_perf_metrics.npu_axi1_rd_data_beat_received, ) logger.info("Done\n") @@ -220,10 +220,11 @@ class EthosUPerformanceEstimator( if backends is None: backends = ["Vela"] # Only Vela is always available as default for backend in backends: - if backend != "Vela" and not aiet.is_supported(backend): + if backend != "Vela" and not backend_manager.is_supported(backend): raise ValueError( f"Unsupported backend '{backend}'. " - f"Only 'Vela' and {aiet.supported_backends()} are supported." + f"Only 'Vela' and {backend_manager.supported_backends()} " + "are supported." 
) self.backends = set(backends) @@ -242,11 +243,11 @@ class EthosUPerformanceEstimator( if backend == "Vela": vela_estimator = VelaPerformanceEstimator(self.context, self.device) memory_usage = vela_estimator.estimate(tflite_model) - elif backend in aiet.supported_backends(): - aiet_estimator = AIETPerformanceEstimator( + elif backend in backend_manager.supported_backends(): + corstone_estimator = CorstonePerformanceEstimator( self.context, self.device, backend ) - npu_cycles = aiet_estimator.estimate(tflite_model) + npu_cycles = corstone_estimator.estimate(tflite_model) else: logger.warning( "Backend '%s' is not supported for Ethos-U performance " diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/aiet/applications/APPLICATIONS.txt index 09127f8..a702e19 100644 --- a/src/mlia/resources/aiet/applications/APPLICATIONS.txt +++ b/src/mlia/resources/aiet/applications/APPLICATIONS.txt @@ -1,6 +1,7 @@ SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. SPDX-License-Identifier: Apache-2.0 -This directory contains the Generic Inference Runner application packages for AIET +This directory contains the application packages for the Generic Inference +Runner. -Each package should contain its own aiet-config.json file +Each package should contain its own aiet-config.json file. diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/aiet/systems/SYSTEMS.txt index bc27e73..3861769 100644 --- a/src/mlia/resources/aiet/systems/SYSTEMS.txt +++ b/src/mlia/resources/aiet/systems/SYSTEMS.txt @@ -1,8 +1,7 @@ SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. SPDX-License-Identifier: Apache-2.0 -This directory contains the configuration files of the systems for the AIET -middleware. +This directory contains the configuration files of the system backends. 
Supported systems: diff --git a/src/mlia/resources/backends/applications/.gitignore b/src/mlia/resources/backends/applications/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/mlia/resources/backends/applications/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/mlia/resources/backends/systems/.gitignore b/src/mlia/resources/backends/systems/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/mlia/resources/backends/systems/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/mlia/tools/aiet_wrapper.py b/src/mlia/tools/aiet_wrapper.py deleted file mode 100644 index 73e82ee..0000000 --- a/src/mlia/tools/aiet_wrapper.py +++ /dev/null @@ -1,435 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Module for AIET integration.""" -import logging -import re -from abc import ABC -from abc import abstractmethod -from dataclasses import dataclass -from pathlib import Path -from typing import Any -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple - -from aiet.backend.application import get_available_applications -from aiet.backend.application import install_application -from aiet.backend.system import get_available_systems -from aiet.backend.system import install_system -from mlia.utils.proc import CommandExecutor -from mlia.utils.proc import OutputConsumer -from mlia.utils.proc import RunningCommand - - -logger = logging.getLogger(__name__) - -# Mapping backend -> device_type -> system_name -_SUPPORTED_SYSTEMS = { - "Corstone-300": { - "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55", - "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65", - }, - "Corstone-310": { - "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55", - }, -} - -# Mapping system_name -> memory_mode -> application -_SYSTEM_TO_APP_MAP = { - "Corstone-300: Cortex-M55+Ethos-U55": { - "Sram": "Generic Inference Runner: Ethos-U55 SRAM", - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - }, - "Corstone-300: Cortex-M55+Ethos-U65": { - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM", - }, - "Corstone-310: Cortex-M85+Ethos-U55": { - "Sram": "Generic Inference Runner: Ethos-U55 SRAM", - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - }, -} - - -def get_system_name(backend: str, device_type: str) -> str: - """Get the AIET system name for the given backend and device type.""" - return _SUPPORTED_SYSTEMS[backend][device_type] - - -def is_supported(backend: str, device_type: Optional[str] = None) -> bool: - """Check if the backend (and optionally device 
type) is supported.""" - if device_type is None: - return backend in _SUPPORTED_SYSTEMS - - try: - get_system_name(backend, device_type) - return True - except KeyError: - return False - - -def supported_backends() -> List[str]: - """Get a list of all backends supported by the AIET wrapper.""" - return list(_SUPPORTED_SYSTEMS.keys()) - - -def get_all_system_names(backend: str) -> List[str]: - """Get all systems supported by the backend.""" - return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) - - -def get_all_application_names(backend: str) -> List[str]: - """Get all applications supported by the backend.""" - app_set = { - app - for sys in get_all_system_names(backend) - for app in _SYSTEM_TO_APP_MAP[sys].values() - } - return list(app_set) - - -@dataclass -class DeviceInfo: - """Device information.""" - - device_type: Literal["ethos-u55", "ethos-u65"] - mac: int - memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"] - - -@dataclass -class ModelInfo: - """Model info.""" - - model_path: Path - - -@dataclass -class PerformanceMetrics: - """Performance metrics parsed from generic inference output.""" - - npu_active_cycles: int - npu_idle_cycles: int - npu_total_cycles: int - npu_axi0_rd_data_beat_received: int - npu_axi0_wr_data_beat_written: int - npu_axi1_rd_data_beat_received: int - - -@dataclass -class ExecutionParams: - """Application execution params.""" - - application: str - system: str - application_params: List[str] - system_params: List[str] - deploy_params: List[str] - - -class AIETLogWriter(OutputConsumer): - """Redirect AIET command output to the logger.""" - - def feed(self, line: str) -> None: - """Process line from the output.""" - logger.debug(line.strip()) - - -class GenericInferenceOutputParser(OutputConsumer): - """Generic inference app output parser.""" - - PATTERNS = { - name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns) - for name, patterns in ( - ( - "npu_active_cycles", - ( - r"NPU ACTIVE cycles: (?P<value>\d+)",
- r"NPU ACTIVE: (?P<value>\d+) cycles", - ), - ), - ( - "npu_idle_cycles", - ( - r"NPU IDLE cycles: (?P<value>\d+)", - r"NPU IDLE: (?P<value>\d+) cycles", - ), - ), - ( - "npu_total_cycles", - ( - r"NPU TOTAL cycles: (?P<value>\d+)", - r"NPU TOTAL: (?P<value>\d+) cycles", - ), - ), - ( - "npu_axi0_rd_data_beat_received", - ( - r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)", - r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats", - ), - ), - ( - "npu_axi0_wr_data_beat_written", - ( - r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P<value>\d+)", - r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P<value>\d+) beats", - ), - ), - ( - "npu_axi1_rd_data_beat_received", - ( - r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)", - r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats", - ), - ), - ) - } - - def __init__(self) -> None: - """Init generic inference output parser instance.""" - self.result: Dict = {} - - def feed(self, line: str) -> None: - """Feed new line to the parser.""" - for name, patterns in self.PATTERNS.items(): - for pattern in patterns: - match = pattern.search(line) - - if match: - self.result[name] = int(match["value"]) - return - - def is_ready(self) -> bool: - """Return true if all expected data has been parsed.""" - return self.result.keys() == self.PATTERNS.keys() - - def missed_keys(self) -> List[str]: - """Return list of the keys that have not been found in the output.""" - return sorted(self.PATTERNS.keys() - self.result.keys()) - - -class AIETRunner: - """AIET runner.""" - - def __init__(self, executor: CommandExecutor) -> None: - """Init AIET runner instance.""" - self.executor = executor - - @staticmethod - def get_installed_systems() -> List[str]: - """Get list of the installed systems.""" - return [system.name for system in get_available_systems()] - - @staticmethod - def get_installed_applications(system: Optional[str] = None) -> List[str]: - """Get list of the installed application.""" - return [ - app.name - for app in get_available_applications() - if system is None or app.can_run_on(system) - ] - - def
is_application_installed(self, application: str, system: str) -> bool: - """Return true if requested application installed.""" - return application in self.get_installed_applications(system) - - def is_system_installed(self, system: str) -> bool: - """Return true if requested system installed.""" - return system in self.get_installed_systems() - - def systems_installed(self, systems: List[str]) -> bool: - """Check if all provided systems are installed.""" - if not systems: - return False - - installed_systems = self.get_installed_systems() - return all(system in installed_systems for system in systems) - - def applications_installed(self, applications: List[str]) -> bool: - """Check if all provided applications are installed.""" - if not applications: - return False - - installed_apps = self.get_installed_applications() - return all(app in installed_apps for app in applications) - - def all_installed(self, systems: List[str], apps: List[str]) -> bool: - """Check if all provided artifacts are installed.""" - return self.systems_installed(systems) and self.applications_installed(apps) - - @staticmethod - def install_system(system_path: Path) -> None: - """Install system.""" - install_system(system_path) - - @staticmethod - def install_application(app_path: Path) -> None: - """Install application.""" - install_application(app_path) - - def run_application(self, execution_params: ExecutionParams) -> RunningCommand: - """Run requested application.""" - command = [ - "aiet", - "application", - "run", - "-n", - execution_params.application, - "-s", - execution_params.system, - *self._params("-p", execution_params.application_params), - *self._params("--system-param", execution_params.system_params), - *self._params("--deploy", execution_params.deploy_params), - ] - - return self._submit(command) - - @staticmethod - def _params(name: str, params: List[str]) -> List[str]: - return [p for item in [(name, param) for param in params] for p in item] - - def _submit(self, 
command: List[str]) -> RunningCommand: - """Submit command for the execution.""" - logger.debug("Submit command %s", " ".join(command)) - return self.executor.submit(command) - - -class GenericInferenceRunner(ABC): - """Abstract class for generic inference runner.""" - - def __init__(self, aiet_runner: AIETRunner): - """Init generic inference runner instance.""" - self.aiet_runner = aiet_runner - self.running_inference: Optional[RunningCommand] = None - - def run( - self, model_info: ModelInfo, output_consumers: List[OutputConsumer] - ) -> None: - """Run generic inference for the provided device/model.""" - execution_params = self.get_execution_params(model_info) - - self.running_inference = self.aiet_runner.run_application(execution_params) - self.running_inference.output_consumers = output_consumers - self.running_inference.consume_output() - - def stop(self) -> None: - """Stop running inference.""" - if self.running_inference is None: - return - - self.running_inference.stop() - - @abstractmethod - def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: - """Get execution params for the provided model.""" - - def __enter__(self) -> "GenericInferenceRunner": - """Enter context.""" - return self - - def __exit__(self, *_args: Any) -> None: - """Exit context.""" - self.stop() - - def check_system_and_application(self, system_name: str, app_name: str) -> None: - """Check if requested system and application installed.""" - if not self.aiet_runner.is_system_installed(system_name): - raise Exception(f"System {system_name} is not installed") - - if not self.aiet_runner.is_application_installed(app_name, system_name): - raise Exception( - f"Application {app_name} for the system {system_name} " - "is not installed" - ) - - -class GenericInferenceRunnerEthosU(GenericInferenceRunner): - """Generic inference runner on U55/65.""" - - def __init__( - self, aiet_runner: AIETRunner, device_info: DeviceInfo, backend: str - ) -> None: - """Init generic inference 
runner instance.""" - super().__init__(aiet_runner) - - system_name, app_name = self.resolve_system_and_app(device_info, backend) - self.system_name = system_name - self.app_name = app_name - self.device_info = device_info - - @staticmethod - def resolve_system_and_app( - device_info: DeviceInfo, backend: str - ) -> Tuple[str, str]: - """Find appropriate system and application for the provided device/backend.""" - try: - system_name = get_system_name(backend, device_info.device_type) - except KeyError as ex: - raise RuntimeError( - f"Unsupported device {device_info.device_type} " - f"for backend {backend}" - ) from ex - - if system_name not in _SYSTEM_TO_APP_MAP: - raise RuntimeError(f"System {system_name} is not installed") - - try: - app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode] - except KeyError as err: - raise RuntimeError( - f"Unsupported memory mode {device_info.memory_mode}" - ) from err - - return system_name, app_name - - def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: - """Get execution params for Ethos-U55/65.""" - self.check_system_and_application(self.system_name, self.app_name) - - system_params = [ - f"mac={self.device_info.mac}", - f"input_file={model_info.model_path.absolute()}", - ] - - return ExecutionParams( - self.app_name, - self.system_name, - [], - system_params, - [], - ) - - -def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner: - """Get generic runner for provided device and backend.""" - aiet_runner = get_aiet_runner() - return GenericInferenceRunnerEthosU(aiet_runner, device_info, backend) - - -def estimate_performance( - model_info: ModelInfo, device_info: DeviceInfo, backend: str -) -> PerformanceMetrics: - """Get performance estimations.""" - with get_generic_runner(device_info, backend) as generic_runner: - output_parser = GenericInferenceOutputParser() - output_consumers = [output_parser, AIETLogWriter()] - - generic_runner.run(model_info, 
output_consumers) - - if not output_parser.is_ready(): - missed_data = ",".join(output_parser.missed_keys()) - logger.debug( - "Unable to get performance metrics, missed data %s", missed_data - ) - raise Exception("Unable to get performance metrics, insufficient data") - - return PerformanceMetrics(**output_parser.result) - - -def get_aiet_runner() -> AIETRunner: - """Return AIET runner.""" - executor = CommandExecutor() - return AIETRunner(executor) diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py index 7a9d113..a92f81c 100644 --- a/src/mlia/tools/metadata/corstone.py +++ b/src/mlia/tools/metadata/corstone.py @@ -12,7 +12,8 @@ from typing import Iterable from typing import List from typing import Optional -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager +from mlia.backend.fs import get_backend_resources from mlia.tools.metadata.common import DownloadAndInstall from mlia.tools.metadata.common import Installation from mlia.tools.metadata.common import InstallationType @@ -41,8 +42,8 @@ PathChecker = Callable[[Path], Optional[BackendInfo]] BackendInstaller = Callable[[bool, Path], Path] -class AIETMetadata: - """AIET installation metadata.""" +class BackendMetadata: + """Backend installation metadata.""" def __init__( self, @@ -55,7 +56,7 @@ class AIETMetadata: supported_platforms: Optional[List[str]] = None, ) -> None: """ - Initialize AIETMetaData. + Initialize BackendMetadata. Members expected_systems and expected_apps are filled automatically. 
""" @@ -67,15 +68,15 @@ class AIETMetadata: self.download_artifact = download_artifact self.supported_platforms = supported_platforms - self.expected_systems = aiet.get_all_system_names(name) - self.expected_apps = aiet.get_all_application_names(name) + self.expected_systems = backend_manager.get_all_system_names(name) + self.expected_apps = backend_manager.get_all_application_names(name) @property def expected_resources(self) -> Iterable[Path]: """Return list of expected resources.""" resources = [self.system_config, *self.apps_resources] - return (get_mlia_resources() / resource for resource in resources) + return (get_backend_resources() / resource for resource in resources) @property def supported_platform(self) -> bool: @@ -86,49 +87,49 @@ class AIETMetadata: return platform.system() in self.supported_platforms -class AIETBasedInstallation(Installation): - """Backend installation based on AIET functionality.""" +class BackendInstallation(Installation): + """Backend installation.""" def __init__( self, - aiet_runner: aiet.AIETRunner, - metadata: AIETMetadata, + backend_runner: backend_manager.BackendRunner, + metadata: BackendMetadata, path_checker: PathChecker, backend_installer: Optional[BackendInstaller], ) -> None: - """Init the tool installation.""" - self.aiet_runner = aiet_runner + """Init the backend installation.""" + self.backend_runner = backend_runner self.metadata = metadata self.path_checker = path_checker self.backend_installer = backend_installer @property def name(self) -> str: - """Return name of the tool.""" + """Return name of the backend.""" return self.metadata.name @property def description(self) -> str: - """Return description of the tool.""" + """Return description of the backend.""" return self.metadata.description @property def already_installed(self) -> bool: - """Return true if tool already installed.""" - return self.aiet_runner.all_installed( + """Return true if backend already installed.""" + return 
self.backend_runner.all_installed( self.metadata.expected_systems, self.metadata.expected_apps ) @property def could_be_installed(self) -> bool: - """Return true if tool could be installed.""" + """Return true if backend could be installed.""" if not self.metadata.supported_platform: return False return all_paths_valid(self.metadata.expected_resources) def supports(self, install_type: InstallationType) -> bool: - """Return true if tools supported type of the installation.""" + """Return true if backends supported type of the installation.""" if isinstance(install_type, DownloadAndInstall): return self.metadata.download_artifact is not None @@ -138,7 +139,7 @@ class AIETBasedInstallation(Installation): return False # type: ignore def install(self, install_type: InstallationType) -> None: - """Install the tool.""" + """Install the backend.""" if isinstance(install_type, DownloadAndInstall): download_artifact = self.metadata.download_artifact assert download_artifact is not None, "No artifact provided" @@ -153,7 +154,7 @@ class AIETBasedInstallation(Installation): raise Exception(f"Unable to install {install_type}") def install_from(self, backend_info: BackendInfo) -> None: - """Install tool from the directory.""" + """Install backend from the directory.""" mlia_resources = get_mlia_resources() with temp_directory() as tmpdir: @@ -169,15 +170,15 @@ class AIETBasedInstallation(Installation): copy_all(*resources_to_copy, dest=fvp_dist_dir) - self.aiet_runner.install_system(fvp_dist_dir) + self.backend_runner.install_system(fvp_dist_dir) for app in self.metadata.apps_resources: - self.aiet_runner.install_application(mlia_resources / app) + self.backend_runner.install_application(mlia_resources / app) def download_and_install( self, download_artifact: DownloadArtifact, eula_agrement: bool ) -> None: - """Download and install the tool.""" + """Download and install the backend.""" with temp_directory() as tmpdir: try: downloaded_to = download_artifact.download_to(tmpdir) @@ 
-307,10 +308,10 @@ class Corstone300Installer: def get_corstone_300_installation() -> Installation: """Get Corstone-300 installation.""" - corstone_300 = AIETBasedInstallation( - aiet_runner=aiet.get_aiet_runner(), + corstone_300 = BackendInstallation( + backend_runner=backend_manager.BackendRunner(), # pylint: disable=line-too-long - metadata=AIETMetadata( + metadata=BackendMetadata( name="Corstone-300", description="Corstone-300 FVP", system_config="aiet/systems/corstone-300/aiet-config.json", @@ -356,10 +357,10 @@ def get_corstone_300_installation() -> Installation: def get_corstone_310_installation() -> Installation: """Get Corstone-310 installation.""" - corstone_310 = AIETBasedInstallation( - aiet_runner=aiet.get_aiet_runner(), + corstone_310 = BackendInstallation( + backend_runner=backend_manager.BackendRunner(), # pylint: disable=line-too-long - metadata=AIETMetadata( + metadata=BackendMetadata( name="Corstone-310", description="Corstone-310 FVP", system_config="aiet/systems/corstone-310/aiet-config.json", diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py index 39aca43..18a4305 100644 --- a/src/mlia/utils/proc.py +++ b/src/mlia/utils/proc.py @@ -8,7 +8,6 @@ import time from abc import ABC from abc import abstractmethod from contextlib import contextmanager -from contextlib import suppress from pathlib import Path from typing import Any from typing import Generator @@ -23,7 +22,7 @@ class OutputConsumer(ABC): @abstractmethod def feed(self, line: str) -> None: - """Feed new line to the consumerr.""" + """Feed new line to the consumer.""" class RunningCommand: @@ -32,7 +31,7 @@ class RunningCommand: def __init__(self, process: subprocess.Popen) -> None: """Init running command instance.""" self.process = process - self._output_consumers: Optional[List[OutputConsumer]] = None + self.output_consumers: List[OutputConsumer] = [] def is_alive(self) -> bool: """Return true if process is still alive.""" @@ -57,25 +56,14 @@ class RunningCommand: """Send 
signal to the process.""" self.process.send_signal(signal_num) - @property - def output_consumers(self) -> Optional[List[OutputConsumer]]: - """Property output_consumers.""" - return self._output_consumers - - @output_consumers.setter - def output_consumers(self, output_consumers: List[OutputConsumer]) -> None: - """Set output consumers.""" - self._output_consumers = output_consumers - def consume_output(self) -> None: """Pass program's output to the consumers.""" - if self.process is None or self.output_consumers is None: + if self.process is None or not self.output_consumers: return for line in self.stdout(): for consumer in self.output_consumers: - with suppress(): - consumer.feed(line) + consumer.feed(line) def stop( self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5 -- cgit v1.2.1