diff options
Diffstat (limited to 'src/aiet')
30 files changed, 0 insertions, 4550 deletions
diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py deleted file mode 100644 index 49304b1..0000000 --- a/src/aiet/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Init of aiet.""" -import pkg_resources - - -__version__ = pkg_resources.get_distribution("mlia").version diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py deleted file mode 100644 index 3d60372..0000000 --- a/src/aiet/backend/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Backend module.""" diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py deleted file mode 100644 index f6ef815..0000000 --- a/src/aiet/backend/application.py +++ /dev/null @@ -1,187 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application backend module.""" -import re -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_application_or_tool_configs -from aiet.backend.common import load_config -from aiet.backend.common import remove_backend -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import ExtendedApplicationConfig -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import get_source -from aiet.utils.fs import get_resources - - -def get_available_application_directory_names() -> List[str]: - """Return a list of directory 
names for all available applications.""" - return [entry.name for entry in get_backend_directories("applications")] - - -def get_available_applications() -> List["Application"]: - """Return a list with all available applications.""" - available_applications = [] - for config_json in get_backend_configs("applications"): - config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json)) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - applications = load_applications(config_entry) - available_applications += applications - - return sorted(available_applications, key=lambda application: application.name) - - -def get_application( - application_name: str, system_name: Optional[str] = None -) -> List["Application"]: - """Return a list of application instances with provided name.""" - return [ - application - for application in get_available_applications() - if application.name == application_name - and (not system_name or application.can_run_on(system_name)) - ] - - -def install_application(source_path: Path) -> None: - """Install application.""" - try: - source = get_source(source_path) - config = cast(List[ExtendedApplicationConfig], source.config()) - applications_to_install = [ - s for entry in config for s in load_applications(entry) - ] - except Exception as error: - raise ConfigurationException("Unable to read application definition") from error - - if not applications_to_install: - raise ConfigurationException("No application definition found") - - available_applications = get_available_applications() - already_installed = [ - s for s in applications_to_install if s in available_applications - ] - if already_installed: - names = {application.name for application in already_installed} - raise ConfigurationException( - "Applications [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_resources("applications")) - - -def 
remove_application(directory_name: str) -> None: - """Remove application directory.""" - remove_backend(directory_name, "applications") - - -def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: - """Extract a list of unique application names of all application available.""" - return list( - set( - application.name - for application in get_available_applications() - if not system_name or application.can_run_on(system_name) - ) - ) - - -class Application(Backend): - """Class for representing a single application component.""" - - def __init__(self, config: ApplicationConfig) -> None: - """Construct a Application instance from a dict.""" - super().__init__(config) - - self.supported_systems = config.get("supported_systems", []) - self.deploy_data = config.get("deploy_data", []) - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Application): - return False - - return ( - super().__eq__(other) - and self.name == other.name - and set(self.supported_systems) == set(other.supported_systems) - ) - - def can_run_on(self, system_name: str) -> bool: - """Check if the application can run on the system passed as argument.""" - return system_name in self.supported_systems - - def get_deploy_data(self) -> List[DataPaths]: - """Validate and return data specified in the config file.""" - if self.config_location is None: - raise ConfigurationException( - "Unable to get application {} config location".format(self.name) - ) - - deploy_data = [] - for item in self.deploy_data: - src, dst = item - src_full_path = self.config_location / src - assert src_full_path.exists(), "{} does not exists".format(src_full_path) - deploy_data.append(DataPaths(src_full_path, dst)) - return deploy_data - - def get_details(self) -> Dict[str, Any]: - """Return dictionary with information about the Application instance.""" - output = { - "type": "application", - "name": self.name, - "description": self.description, - 
"supported_systems": self.supported_systems, - "commands": self._get_command_details(), - } - - return output - - def remove_unused_params(self) -> None: - """Remove unused params in commands. - - After merging default and system related configuration application - could have parameters that are not being used in commands. They - should be removed. - """ - for command in self.commands.values(): - indexes_or_aliases = [ - m - for cmd_str in command.command_strings - for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str) - ] - - only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) - if only_aliases: - used_params = [ - param - for param in command.params - if param.alias in indexes_or_aliases - ] - command.params = used_params - - -def load_applications(config: ExtendedApplicationConfig) -> List[Application]: - """Load application. - - Application configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - Application instance with appropriate configuration. - """ - configs = load_application_or_tool_configs(config, ApplicationConfig) - applications = [Application(cfg) for cfg in configs] - for application in applications: - application.remove_unused_params() - return applications diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py deleted file mode 100644 index b887ee7..0000000 --- a/src/aiet/backend/common.py +++ /dev/null @@ -1,532 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Contain all common functions for the backends.""" -import json -import logging -import re -from abc import ABC -from collections import Counter -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict -from typing import Final -from typing import IO -from typing import Iterable -from typing import List -from typing import Match -from typing import NamedTuple -from typing import Optional -from typing import Pattern -from typing import Tuple -from typing import Type -from typing import Union - -from aiet.backend.config import BackendConfig -from aiet.backend.config import BaseBackendConfig -from aiet.backend.config import NamedExecutionConfig -from aiet.backend.config import UserParamConfig -from aiet.backend.config import UserParamsConfig -from aiet.utils.fs import get_resources -from aiet.utils.fs import remove_resource -from aiet.utils.fs import ResourceType - - -AIET_CONFIG_FILE: Final[str] = "aiet-config.json" - - -class ConfigurationException(Exception): - """Configuration exception.""" - - -def get_backend_config(dir_path: Path) -> Path: - """Get path to backendir configuration file.""" - return dir_path / AIET_CONFIG_FILE - - -def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend configs for provided resource_type.""" - return ( - get_backend_config(entry) for entry in get_backend_directories(resource_type) - ) - - -def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend directories for provided resource_type.""" - return ( - entry - for entry in get_resources(resource_type).iterdir() - if is_backend_directory(entry) - ) - - -def is_backend_directory(dir_path: Path) -> bool: - """Check if path is backend's configuration directory.""" - return dir_path.is_dir() and get_backend_config(dir_path).is_file() - - -def remove_backend(directory_name: str, 
resource_type: ResourceType) -> None: - """Remove backend with provided type and directory_name.""" - if not directory_name: - raise Exception("No directory name provided") - - remove_resource(directory_name, resource_type) - - -def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: - """Return a loaded json file.""" - if config is None: - raise Exception("Unable to read config") - - if isinstance(config, Path): - with config.open() as json_file: - return cast(BackendConfig, json.load(json_file)) - - return cast(BackendConfig, json.load(config)) - - -def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: - """Split the parameter string in name and optional value. - - It manages the following cases: - --param=1 -> --param, 1 - --param 1 -> --param, 1 - --flag -> --flag, None - """ - data = re.split(" |=", parameter) - if len(data) == 1: - param_name = data[0] - param_value = None - else: - param_name = " ".join(data[0:-1]) - param_value = data[-1] - return param_name, param_value - - -class DataPaths(NamedTuple): - """DataPaths class.""" - - src: Path - dst: str - - -class Backend(ABC): - """Backend class.""" - - # pylint: disable=too-many-instance-attributes - - def __init__(self, config: BaseBackendConfig): - """Initialize backend.""" - name = config.get("name") - if not name: - raise ConfigurationException("Name is empty") - - self.name = name - self.description = config.get("description", "") - self.config_location = config.get("config_location") - self.variables = config.get("variables", {}) - self.build_dir = config.get("build_dir") - self.lock = config.get("lock", False) - if self.build_dir: - self.build_dir = self._substitute_variables(self.build_dir) - self.annotations = config.get("annotations", {}) - - self._parse_commands_and_params(config) - - def validate_parameter(self, command_name: str, parameter: str) -> bool: - """Validate the parameter string against the application configuration. 
- - We take the parameter string, extract the parameter name/value and - check them against the current configuration. - """ - param_name, param_value = parse_raw_parameter(parameter) - valid_param_name = valid_param_value = False - - command = self.commands.get(command_name) - if not command: - raise AttributeError("Unknown command: '{}'".format(command_name)) - - # Iterate over all available parameters until we have a match. - for param in command.params: - if self._same_parameter(param_name, param): - valid_param_name = True - # This is a non-empty list - if param.values: - # We check if the value is allowed in the configuration - valid_param_value = param_value in param.values - else: - # In this case we don't validate the value and accept - # whatever we have set. - valid_param_value = True - break - - return valid_param_name and valid_param_value - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Backend): - return False - - return ( - self.name == other.name - and self.description == other.description - and self.commands == other.commands - ) - - def __repr__(self) -> str: - """Represent the Backend instance by its name.""" - return self.name - - def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: - """Parse commands and user parameters.""" - self.commands: Dict[str, Command] = {} - - commands = config.get("commands") - if commands: - params = config.get("user_params") - - for command_name in commands.keys(): - command_params = self._parse_params(params, command_name) - command_strings = [ - self._substitute_variables(cmd) - for cmd in commands.get(command_name, []) - ] - self.commands[command_name] = Command(command_strings, command_params) - - def _substitute_variables(self, str_val: str) -> str: - """Substitute variables in string. - - Variables is being substituted at backend's creation stage because - they could contain references to other params which will be - resolved later. 
- """ - if not str_val: - return str_val - - var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}") - - def var_value(match: Match) -> str: - var_name = match["var_name"] - if var_name not in self.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) - - return self.variables[var_name] - - return var_pattern.sub(var_value, str_val) # type: ignore - - @classmethod - def _parse_params( - cls, params: Optional[UserParamsConfig], command: str - ) -> List["Param"]: - if not params: - return [] - - return [cls._parse_param(p) for p in params.get(command, [])] - - @classmethod - def _parse_param(cls, param: UserParamConfig) -> "Param": - """Parse a single parameter.""" - name = param.get("name") - if name is not None and not name: - raise ConfigurationException("Parameter has an empty 'name' attribute.") - values = param.get("values", None) - default_value = param.get("default_value", None) - description = param.get("description", "") - alias = param.get("alias") - - return Param( - name=name, - description=description, - values=values, - default_value=default_value, - alias=alias, - ) - - def _get_command_details(self) -> Dict: - command_details = { - command_name: command.get_details() - for command_name, command in self.commands.items() - } - return command_details - - def _get_user_param_value( - self, user_params: List[str], param: "Param" - ) -> Optional[str]: - """Get the user-specified value of a parameter.""" - for user_param in user_params: - user_param_name, user_param_value = parse_raw_parameter(user_param) - if user_param_name == param.name: - warn_message = ( - "The direct use of parameter name is deprecated" - " and might be removed in the future.\n" - f"Please use alias '{param.alias}' instead of " - "'{user_param_name}' to provide the parameter." 
- ) - logging.warning(warn_message) - - if self._same_parameter(user_param_name, param): - return user_param_value - - return None - - @staticmethod - def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: - """Compare user parameter name with param name or alias.""" - # Strip the "=" sign in the param_name. This is needed just for - # comparison with the parameters passed by the user. - # The equal sign needs to be honoured when re-building the - # parameter back. - param_name = None if not param.name else param.name.rstrip("=") - return user_param_name_or_alias in [param_name, param.alias] - - def resolved_parameters( - self, command_name: str, user_params: List[str] - ) -> List[Tuple[Optional[str], "Param"]]: - """Return list of parameters with values.""" - result: List[Tuple[Optional[str], "Param"]] = [] - command = self.commands.get(command_name) - if not command: - return result - - for param in command.params: - value = self._get_user_param_value(user_params, param) - if not value: - value = param.default_value - result.append((value, param)) - - return result - - def build_command( - self, - command_name: str, - user_params: List[str], - param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], - ) -> List[str]: - """ - Return a list of executable command strings. - - Given a command and associated parameters, returns a list of executable command - strings. 
- """ - command = self.commands.get(command_name) - if not command: - raise ConfigurationException( - "Command '{}' could not be found.".format(command_name) - ) - - commands_to_run = [] - - params_values = self.resolved_parameters(command_name, user_params) - for cmd_str in command.command_strings: - cmd_str = resolve_all_parameters( - cmd_str, param_resolver, command_name, params_values - ) - commands_to_run.append(cmd_str) - - return commands_to_run - - -class Param: - """Class for representing a generic application parameter.""" - - def __init__( # pylint: disable=too-many-arguments - self, - name: Optional[str], - description: str, - values: Optional[List[str]] = None, - default_value: Optional[str] = None, - alias: Optional[str] = None, - ) -> None: - """Construct a Param instance.""" - if not name and not alias: - raise ConfigurationException( - "Either name, alias or both must be set to identify a parameter." - ) - self.name = name - self.values = values - self.description = description - self.default_value = default_value - self.alias = alias - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Param.""" - return {key: value for key, value in self.__dict__.items() if value} - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Param): - return False - - return ( - self.name == other.name - and self.values == other.values - and self.default_value == other.default_value - and self.description == other.description - ) - - -class Command: - """Class for representing a command.""" - - def __init__( - self, command_strings: List[str], params: Optional[List[Param]] = None - ) -> None: - """Construct a Command instance.""" - self.command_strings = command_strings - - if params: - self.params = params - else: - self.params = [] - - self._validate() - - def _validate(self) -> None: - """Validate command.""" - if not self.params: - return - - aliases = [param.alias for param in 
self.params if param.alias is not None] - repeated_aliases = [ - alias for alias, count in Counter(aliases).items() if count > 1 - ] - - if repeated_aliases: - raise ConfigurationException( - "Non unique aliases {}".format(", ".join(repeated_aliases)) - ) - - both_name_and_alias = [ - param.name - for param in self.params - if param.name in aliases and param.name != param.alias - ] - if both_name_and_alias: - raise ConfigurationException( - "Aliases {} could not be used as parameter name".format( - ", ".join(both_name_and_alias) - ) - ) - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Command.""" - output = { - "command_strings": self.command_strings, - "user_params": [param.get_details() for param in self.params], - } - return output - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Command): - return False - - return ( - self.command_strings == other.command_strings - and self.params == other.params - ) - - -def resolve_all_parameters( - str_val: str, - param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], - command_name: Optional[str] = None, - params_values: Optional[List[Tuple[Optional[str], Param]]] = None, -) -> str: - """Resolve all parameters in the string.""" - if not str_val: - return str_val - - param_pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}") - while param_pattern.findall(str_val): - str_val = param_pattern.sub( - lambda m: param_resolver( - m["param_name"], command_name or "", params_values or [] - ), - str_val, - ) - return str_val - - -def load_application_or_tool_configs( - config: Any, - config_type: Type[Any], - is_system_required: bool = True, -) -> Any: - """Get one config for each system supported by the application/tool. - - The configuration could contain different parameters/commands for different - supported systems. 
For each supported system this function will return separate - config with appropriate configuration. - """ - merged_configs = [] - supported_systems: Optional[List[NamedExecutionConfig]] = config.get( - "supported_systems" - ) - if not supported_systems: - if is_system_required: - raise ConfigurationException("No supported systems definition provided") - # Create an empty system to be used in the parsing below - supported_systems = [cast(NamedExecutionConfig, {})] - - default_user_params = config.get("user_params", {}) - - def merge_config(system: NamedExecutionConfig) -> Any: - system_name = system.get("name") - if not system_name and is_system_required: - raise ConfigurationException( - "Unable to read supported system definition, name is missed" - ) - - merged_config = config_type(**config) - merged_config["supported_systems"] = [system_name] if system_name else [] - # merge default configuration and specific to the system - merged_config["commands"] = { - **config.get("commands", {}), - **system.get("commands", {}), - } - - params = {} - tool_user_params = system.get("user_params", {}) - command_names = tool_user_params.keys() | default_user_params.keys() - for command_name in command_names: - if command_name not in merged_config["commands"]: - continue - - params_default = default_user_params.get(command_name, []) - params_tool = tool_user_params.get(command_name, []) - if not params_default or not params_tool: - params[command_name] = params_tool or params_default - if params_default and params_tool: - if any(not p.get("alias") for p in params_default): - raise ConfigurationException( - "Default parameters for command {} should have aliases".format( - command_name - ) - ) - if any(not p.get("alias") for p in params_tool): - raise ConfigurationException( - "{} parameters for command {} should have aliases".format( - system_name, command_name - ) - ) - - merged_by_alias = { - **{p.get("alias"): p for p in params_default}, - **{p.get("alias"): p for p in 
params_tool}, - } - params[command_name] = list(merged_by_alias.values()) - - merged_config["user_params"] = params - merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) - merged_config["lock"] = system.get("lock", config.get("lock", False)) - merged_config["variables"] = { - **config.get("variables", {}), - **system.get("variables", {}), - } - return merged_config - - merged_configs = [merge_config(system) for system in supported_systems] - - return merged_configs diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py deleted file mode 100644 index dd42012..0000000 --- a/src/aiet/backend/config.py +++ /dev/null @@ -1,107 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain definition of backend configuration.""" -from pathlib import Path -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import TypedDict -from typing import Union - - -class UserParamConfig(TypedDict, total=False): - """User parameter configuration.""" - - name: Optional[str] - default_value: str - values: List[str] - description: str - alias: str - - -UserParamsConfig = Dict[str, List[UserParamConfig]] - - -class ExecutionConfig(TypedDict, total=False): - """Execution configuration.""" - - commands: Dict[str, List[str]] - user_params: UserParamsConfig - build_dir: str - variables: Dict[str, str] - lock: bool - - -class NamedExecutionConfig(ExecutionConfig): - """Execution configuration with name.""" - - name: str - - -class BaseBackendConfig(ExecutionConfig, total=False): - """Base backend configuration.""" - - name: str - description: str - config_location: Path - annotations: Dict[str, Union[str, List[str]]] - - -class ApplicationConfig(BaseBackendConfig, total=False): - """Application configuration.""" - - supported_systems: List[str] - deploy_data: List[Tuple[str, str]] - - 
-class ExtendedApplicationConfig(BaseBackendConfig, total=False): - """Extended application configuration.""" - - supported_systems: List[NamedExecutionConfig] - deploy_data: List[Tuple[str, str]] - - -class ProtocolConfig(TypedDict, total=False): - """Protocol config.""" - - protocol: Literal["local", "ssh"] - - -class SSHConfig(ProtocolConfig, total=False): - """SSH configuration.""" - - username: str - password: str - hostname: str - port: str - - -class LocalProtocolConfig(ProtocolConfig, total=False): - """Local protocol config.""" - - -class SystemConfig(BaseBackendConfig, total=False): - """System configuration.""" - - data_transfer: Union[SSHConfig, LocalProtocolConfig] - reporting: Dict[str, Dict] - - -class ToolConfig(BaseBackendConfig, total=False): - """Tool configuration.""" - - supported_systems: List[str] - - -class ExtendedToolConfig(BaseBackendConfig, total=False): - """Extended tool configuration.""" - - supported_systems: List[NamedExecutionConfig] - - -BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig] -BackendConfig = Union[ - List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig] -] diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py deleted file mode 100644 index 2650902..0000000 --- a/src/aiet/backend/controller.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Controller backend module.""" -import time -from pathlib import Path -from typing import List -from typing import Optional -from typing import Tuple - -import psutil -import sh - -from aiet.backend.common import ConfigurationException -from aiet.utils.fs import read_file_as_string -from aiet.utils.proc import execute_command -from aiet.utils.proc import get_stdout_stderr_paths -from aiet.utils.proc import read_process_info -from aiet.utils.proc import save_process_info -from aiet.utils.proc import terminate_command -from aiet.utils.proc import terminate_external_process - - -class SystemController: - """System controller class.""" - - def __init__(self) -> None: - """Create new instance of service controller.""" - self.cmd: Optional[sh.RunningCommand] = None - self.out_path: Optional[Path] = None - self.err_path: Optional[Path] = None - - def before_start(self) -> None: - """Run actions before system start.""" - - def after_start(self) -> None: - """Run actions after system start.""" - - def start(self, commands: List[str], cwd: Path) -> None: - """Start system.""" - if not isinstance(cwd, Path) or not cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(cwd)) - - if len(commands) != 1: - raise ConfigurationException("System should have only one command to run") - - startup_command = commands[0] - if not startup_command: - raise ConfigurationException("No startup command provided") - - self.before_start() - - self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) - - self.cmd = execute_command( - startup_command, - cwd, - bg=True, - out=str(self.out_path), - err=str(self.err_path), - ) - - self.after_start() - - def stop( - self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 - ) -> None: - """Stop system.""" - if self.cmd is not None and self.is_running(): - terminate_command(self.cmd, wait, wait_period, number_of_attempts) - - def is_running(self) 
-> bool: - """Check if underlying process is running.""" - return self.cmd is not None and self.cmd.is_alive() - - def get_output(self) -> Tuple[str, str]: - """Return application output.""" - if self.cmd is None or self.out_path is None or self.err_path is None: - return ("", "") - - return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) - - -class SystemControllerSingleInstance(SystemController): - """System controller with support of system's single instance.""" - - def __init__(self, pid_file_path: Optional[Path] = None) -> None: - """Create new instance of the service controller.""" - super().__init__() - self.pid_file_path = pid_file_path - - def before_start(self) -> None: - """Run actions before system start.""" - self._check_if_previous_instance_is_running() - - def after_start(self) -> None: - """Run actions after system start.""" - self._save_process_info() - - def _check_if_previous_instance_is_running(self) -> None: - """Check if another instance of the system is running.""" - process_info = read_process_info(self._pid_file()) - - for item in process_info: - try: - process = psutil.Process(item.pid) - same_process = ( - process.name() == item.name - and process.exe() == item.executable - and process.cwd() == item.cwd - ) - if same_process: - print( - "Stopping previous instance of the system [{}]".format(item.pid) - ) - terminate_external_process(process) - except psutil.NoSuchProcess: - pass - - def _save_process_info(self, wait_period: float = 2) -> None: - """Save information about system's processes.""" - if self.cmd is None or not self.is_running(): - return - - # give some time for the system to start - time.sleep(wait_period) - - save_process_info(self.cmd.process.pid, self._pid_file()) - - def _pid_file(self) -> Path: - """Return path to file which is used for saving process info.""" - if not self.pid_file_path: - raise Exception("No pid file path presented") - - return self.pid_file_path diff --git 
a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py deleted file mode 100644 index 1653ee2..0000000 --- a/src/aiet/backend/execution.py +++ /dev/null @@ -1,859 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application execution module.""" -import itertools -import json -import random -import re -import string -import sys -import time -import warnings -from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import ContextManager -from typing import Dict -from typing import Generator -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypedDict -from typing import Union - -from filelock import FileLock -from filelock import Timeout - -from aiet.backend.application import Application -from aiet.backend.application import get_application -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import Param -from aiet.backend.common import parse_raw_parameter -from aiet.backend.common import resolve_all_parameters -from aiet.backend.output_parser import Base64OutputParser -from aiet.backend.output_parser import OutputParser -from aiet.backend.output_parser import RegexOutputParser -from aiet.backend.system import ControlledSystem -from aiet.backend.system import get_system -from aiet.backend.system import StandaloneSystem -from aiet.backend.system import System -from aiet.backend.tool import get_tool -from aiet.backend.tool import Tool -from aiet.utils.fs import recreate_directory -from aiet.utils.fs import remove_directory -from aiet.utils.fs import valid_for_filename -from aiet.utils.proc import run_and_wait - - -class 
class AnotherInstanceIsRunningException(Exception):
    """Raised when a concurrent execution of the same system is detected."""


class ConnectionException(Exception):
    """Raised when a connection to the system cannot be established."""


class ExecutionParams(TypedDict, total=False):
    """Optional switches that tune how an execution is performed."""

    disable_locking: bool
    unique_build_dir: bool


class ExecutionContext:
    """Holds everything needed to build and run an application on a system."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        app: Union[Application, Tool],
        app_params: List[str],
        system: Optional[System],
        system_params: List[str],
        custom_deploy_data: Optional[List[DataPaths]] = None,
        execution_params: Optional[ExecutionParams] = None,
        report_file: Optional[Path] = None,
    ):
        """Init execution context."""
        self.app = app
        self.app_params = app_params
        self.custom_deploy_data = custom_deploy_data or []
        self.system = system
        self.system_params = system_params
        self.execution_params = execution_params or ExecutionParams()
        self.report_file = report_file

        # A reporter is only needed when a report file was requested.
        parsers: List[OutputParser] = []
        if self.report_file:
            if system and system.reporting:
                # Regex parser is configured per system.
                parsers.append(RegexOutputParser("system", system.reporting["regex"]))
            # Applications always get the base64 parser.
            parsers.append(Base64OutputParser("application"))
        self.reporter: Optional[Reporter] = (
            Reporter(parsers=parsers) if self.report_file else None
        )

        self.param_resolver = ParamResolver(self)
        self._resolved_build_dir: Optional[Path] = None

    @property
    def is_deploy_needed(self) -> bool:
        """Check if application requires data deployment."""
        if not isinstance(self.app, Application):
            return False
        return bool(self.app.get_deploy_data()) or bool(self.custom_deploy_data)

    @property
    def is_locking_required(self) -> bool:
        """Return true if any form of locking is required."""
        if self._disable_locking():
            return False
        return self.app.lock or (self.system is not None and self.system.lock)

    @property
    def is_build_required(self) -> bool:
        """Return true if the application defines a build step."""
        return "build" in self.app.commands

    @property
    def is_unique_build_dir_required(self) -> bool:
        """Return true if a unique build dir was requested."""
        return self.execution_params.get("unique_build_dir", False)

    def build_dir(self) -> Path:
        """Return the resolved application build dir (cached after first call)."""
        if self._resolved_build_dir is not None:
            return self._resolved_build_dir

        location = self.app.config_location
        if not isinstance(location, Path) or not location.is_dir():
            raise ConfigurationException(
                "Application {} has wrong config location".format(self.app.name)
            )

        raw_dir = self.app.build_dir
        if raw_dir:
            raw_dir = resolve_all_parameters(raw_dir, self.param_resolver)
        if not raw_dir:
            raise ConfigurationException(
                "No build directory defined for the app {}".format(self.app.name)
            )

        if self.is_unique_build_dir_required:
            # Random suffix keeps concurrent runs from clobbering each other.
            suffix = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=7)
            )
            raw_dir = "{}_{}".format(raw_dir, suffix)

        self._resolved_build_dir = location / raw_dir
        return self._resolved_build_dir

    def _disable_locking(self) -> bool:
        """Return true if locking should be disabled."""
        return self.execution_params.get("disable_locking", False)
class ParamResolver:
    """Resolve template parameters against the current execution context."""

    def __init__(self, context: ExecutionContext):
        """Init parameter resolver."""
        self.ctx = context

    @staticmethod
    def resolve_user_params(
        cmd_name: Optional[str],
        index_or_alias: str,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]],
    ) -> str:
        """Render a user-supplied parameter referenced by index or alias."""
        if not cmd_name or resolved_params is None:
            raise ConfigurationException("Unable to resolve user params")

        param_value: Optional[str] = None
        param: Optional[Param] = None

        if index_or_alias.isnumeric():
            idx = int(index_or_alias)
            if not 0 <= idx < len(resolved_params):
                raise ConfigurationException(
                    "Invalid index {} for user params of command {}".format(
                        idx, cmd_name
                    )
                )
            param_value, param = resolved_params[idx]
        else:
            for candidate_value, candidate in resolved_params:
                if candidate.alias == index_or_alias:
                    param_value, param = candidate_value, candidate
                    break

        if param is None:
            raise ConfigurationException(
                "No user parameter for command '{}' with alias '{}'.".format(
                    cmd_name, index_or_alias
                )
            )

        if param_value:
            # Two flavours of parameters are supported here:
            #   1) optional parameters (non-positional, name plus value)
            #   2) positional parameters (value only, no name)
            # For positional ones both name and separator stay empty.
            name_part = ""
            separator = ""
            if param.name is not None:
                # A name ending in '=' already carries its own separator;
                # otherwise a space splits the name from the value.
                name_part = param.name
                separator = "" if param.name.endswith("=") else " "
            return name_part + separator + param_value

        if param.name is None:
            raise ConfigurationException(
                "Missing user parameter with alias '{}' for command '{}'.".format(
                    index_or_alias, cmd_name
                )
            )

        return param.name  # flag: just return the parameter name

    def resolve_commands_and_params(
        self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str
    ) -> str:
        """Resolve a command line or one of its parameter values."""
        if backend_type == "system":
            backend = cast(Backend, self.ctx.system)
            backend_params = self.ctx.system_params
        else:  # application or tool backend
            backend = cast(Backend, self.ctx.app)
            backend_params = self.ctx.app_params

        if cmd_name not in backend.commands:
            raise ConfigurationException("Command {} not found".format(cmd_name))

        if return_params:
            params = backend.resolved_parameters(cmd_name, backend_params)
            if index_or_alias.isnumeric():
                idx = int(index_or_alias)
                if not 0 <= idx < len(params):
                    raise ConfigurationException(
                        "Invalid parameter index {} for command {}".format(
                            idx, cmd_name
                        )
                    )
                param_value = params[idx][0]
            else:
                param_value = None
                for value, param in params:
                    if param.alias == index_or_alias:
                        param_value = value
                        break

            # NOTE(review): a falsy value ("" or None) is treated as missing
            # here — values such as "0" pass, empty strings do not.
            if not param_value:
                raise ConfigurationException(
                    (
                        "No value for parameter with index or alias {} of command {}"
                    ).format(index_or_alias, cmd_name)
                )
            return param_value

        if not index_or_alias.isnumeric():
            raise ConfigurationException("Bad command index {}".format(index_or_alias))

        idx = int(index_or_alias)
        commands = backend.build_command(cmd_name, backend_params, self.param_resolver)
        if not 0 <= idx < len(commands):
            raise ConfigurationException(
                "Invalid index {} for command {}".format(idx, cmd_name)
            )
        return commands[idx]

    def resolve_variables(self, backend_type: str, var_name: str) -> str:
        """Resolve variable value."""
        if backend_type == "system":
            backend = cast(Backend, self.ctx.system)
        else:  # application or tool backend
            backend = cast(Backend, self.ctx.app)

        if var_name not in backend.variables:
            raise ConfigurationException("Unknown variable {}".format(var_name))

        return backend.variables[var_name]

    def param_matcher(
        self,
        param_name: str,
        cmd_name: Optional[str],
        resolved_params: Optional[List[Tuple[Optional[str], Param]]],
    ) -> str:
        """Resolve a parameter reference via pattern matching.

        Supports references like "application.commands.run:0" and
        "system.commands.run.params:0". 'software' is accepted for
        backward compatibility.
        """
        commands_and_params_match = re.match(
            r"(?P<type>application|software|tool|system)[.]commands[.]"
            r"(?P<name>\w+)"
            r"(?P<params>[.]params|)[:]"
            r"(?P<index_or_alias>\w+)",
            param_name,
        )
        if commands_and_params_match:
            return self.resolve_commands_and_params(
                commands_and_params_match["type"],
                commands_and_params_match["name"],
                bool(commands_and_params_match["params"]),
                commands_and_params_match["index_or_alias"],
            )

        variables_match = re.match(
            r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)",
            param_name,
        )
        if variables_match:
            return self.resolve_variables(
                variables_match["type"], variables_match["var_name"]
            )

        user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name)
        if user_params_match:
            return self.resolve_user_params(
                cmd_name, user_params_match["index_or_alias"], resolved_params
            )

        raise ConfigurationException(
            "Unable to resolve parameter {}".format(param_name)
        )

    def param_resolver(
        self,
        param_name: str,
        cmd_name: Optional[str] = None,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
    ) -> str:
        """Resolve parameter value based on current execution context."""
        # 'software.*' names are kept for backward compatibility.
        app = self.ctx.app
        resolved: Optional[str] = None

        if param_name in {"application.name", "tool.name", "software.name"}:
            resolved = app.name
        elif param_name in {
            "application.description",
            "tool.description",
            "software.description",
        }:
            resolved = app.description
        elif app.config_location and param_name in {
            "application.config_dir",
            "tool.config_dir",
            "software.config_dir",
        }:
            resolved = str(app.config_location.absolute())
        elif app.build_dir and param_name in {
            "application.build_dir",
            "tool.build_dir",
            "software.build_dir",
        }:
            resolved = str(self.ctx.build_dir().absolute())
        elif self.ctx.system is not None:
            if param_name == "system.name":
                resolved = self.ctx.system.name
            elif param_name == "system.description":
                resolved = self.ctx.system.description
            elif param_name == "system.config_dir" and self.ctx.system.config_location:
                resolved = str(self.ctx.system.config_location.absolute())

        if not resolved:
            # Fall back to the regex matcher for command/variable/user
            # parameter references.
            resolved = self.param_matcher(param_name, cmd_name, resolved_params)
        return resolved

    def __call__(
        self,
        param_name: str,
        cmd_name: Optional[str] = None,
        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
    ) -> str:
        """Resolve provided parameter."""
        return self.param_resolver(param_name, cmd_name, resolved_params)
class Reporter:
    """Collect metrics from simulation output and produce a report."""

    def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None:
        """Create a reporter; without parsers only static info is reported."""
        self.parsers: List[OutputParser] = parsers if parsers is not None else []
        self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict))

    def parse(self, output: bytearray) -> None:
        """Run every parser over the output, merging metrics per parser."""
        for parser in self.parsers:
            # Merge metrics from different parsers (do not overwrite)
            self._report[parser.name]["metrics"].update(parser(output))

    def get_filtered_output(self, output: bytearray) -> bytearray:
        """Strip parsed content from the output, one parser at a time."""
        for parser in self.parsers:
            output = parser.filter_out_parsed_content(output)
        return output

    def report(self, ctx: ExecutionContext) -> Dict[str, Any]:
        """Combine static context info with the parsed metrics."""
        combined: Dict[str, Any] = defaultdict(dict)
        combined.update(self._static_info(ctx))
        for section, values in self._report.items():
            combined[section].update(values)
        return combined

    @staticmethod
    def save(report: Dict[str, Any], report_file: Path) -> None:
        """Serialise the report to a JSON file."""
        with open(report_file, "w", encoding="utf-8") as file:
            json.dump(report, file, indent=4)

    @staticmethod
    def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]:
        """Map every declared parameter to its CLI value, or its default.

        CLI values win over the defaults declared by the backend.
        """
        # ["p1=v1", "p2=v2"] -> {"p1": "v1", "p2": "v2"}
        overrides = dict(parse_raw_parameter(expr) for expr in cli_params)

        all_params = {
            (param.alias or param.name): overrides.get(
                cast(str, param.name), cast(str, param.default_value)
            )
            for command in backend.commands.values()
            for param in command.params
        }
        return cast(Dict[str, str], all_params)

    @staticmethod
    def _static_info(ctx: ExecutionContext) -> Dict[str, Any]:
        """Extract static simulation information from the context."""
        if ctx.system is None:
            raise ValueError("No system available to report.")

        return {
            "system": {
                "name": ctx.system.name,
                "params": Reporter._compute_all_params(ctx.system_params, ctx.system),
            },
            "application": {
                "name": ctx.app.name,
                "params": Reporter._compute_all_params(ctx.app_params, ctx.app),
            },
        }


def validate_parameters(
    backend: Backend, command_names: List[str], params: List[str]
) -> None:
    """Check that each parameter is accepted by at least one of the commands."""
    for param in params:
        acceptable = any(
            backend.validate_parameter(command_name, param)
            for command_name in command_names
            if command_name in backend.commands
        )
        if not acceptable:
            backend_type = "System" if isinstance(backend, System) else "Application"
            raise ValueError(
                "{} parameter '{}' not valid for command '{}'".format(
                    backend_type, param, " or ".join(command_names)
                )
            )
def get_application_by_name_and_system(
    application_name: str, system_name: str
) -> Application:
    """Return the single application supporting the given system."""
    candidates = get_application(application_name, system_name)
    if not candidates:
        raise ValueError(
            "Application '{}' doesn't support the system '{}'".format(
                application_name, system_name
            )
        )
    # More than one match means the configuration is ambiguous.
    if len(candidates) != 1:
        raise ValueError(
            "Error during getting application {} for the system {}".format(
                application_name, system_name
            )
        )
    return candidates[0]


def get_application_and_system(
    application_name: str, system_name: str
) -> Tuple[Application, System]:
    """Return application and system by provided names."""
    system = get_system(system_name)
    if not system:
        raise ValueError("System {} is not found".format(system_name))
    return get_application_by_name_and_system(application_name, system_name), system


def execute_application_command(  # pylint: disable=too-many-arguments
    command_name: str,
    application_name: str,
    application_params: List[str],
    system_name: str,
    system_params: List[str],
    custom_deploy_data: List[DataPaths],
) -> None:
    """Execute application command.

    .. deprecated:: 21.12
    """
    warnings.warn(
        "Use 'run_application()' instead. Use of 'execute_application_command()' is "
        "deprecated and might be removed in a future release.",
        DeprecationWarning,
    )

    if command_name not in ("build", "run"):
        raise ConfigurationException("Unsupported command {}".format(command_name))

    application, system = get_application_and_system(application_name, system_name)
    validate_parameters(application, [command_name], application_params)
    validate_parameters(system, [command_name], system_params)

    ctx = ExecutionContext(
        app=application,
        app_params=application_params,
        system=system,
        system_params=system_params,
        custom_deploy_data=custom_deploy_data,
    )

    if command_name == "run":
        execute_application_command_run(ctx)
    else:
        execute_application_command_build(ctx)
# pylint: disable=too-many-arguments
def run_application(
    application_name: str,
    application_params: List[str],
    system_name: str,
    system_params: List[str],
    custom_deploy_data: List[DataPaths],
    report_file: Optional[Path] = None,
) -> None:
    """Build (when required) and run an application on the provided system."""
    application, system = get_application_and_system(application_name, system_name)
    validate_parameters(application, ["build", "run"], application_params)
    validate_parameters(system, ["build", "run"], system_params)

    execution_params = ExecutionParams()
    if isinstance(system, StandaloneSystem):
        # Standalone systems may run concurrently: no locking, and every
        # run gets its own build directory.
        execution_params["disable_locking"] = True
        execution_params["unique_build_dir"] = True

    ctx = ExecutionContext(
        app=application,
        app_params=application_params,
        system=system,
        system_params=system_params,
        custom_deploy_data=custom_deploy_data,
        execution_params=execution_params,
        report_file=report_file,
    )

    with build_dir_manager(ctx):
        if ctx.is_build_required:
            execute_application_command_build(ctx)
        execute_application_command_run(ctx)


def execute_application_command_build(ctx: ExecutionContext) -> None:
    """Execute application command 'build'."""
    with ExitStack() as context_stack:
        for manager in get_context_managers("build", ctx):
            context_stack.enter_context(manager(ctx))

        build_dir = ctx.build_dir()
        # Build always starts from a clean directory.
        recreate_directory(build_dir)

        execute_commands_locally(
            ctx.app.build_command("build", ctx.app_params, ctx.param_resolver),
            build_dir,
        )


def execute_commands_locally(commands: List[str], cwd: Path) -> None:
    """Run each command in cwd, echoing it first; abort on failure."""
    for command in commands:
        print("Running: {}".format(command))
        run_and_wait(
            command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr
        )
def execute_application_command_run(ctx: ExecutionContext) -> None:
    """Execute application command."""
    assert ctx.system is not None, "System must be provided."
    if ctx.is_deploy_needed and not ctx.system.supports_deploy:
        raise ConfigurationException(
            "System {} does not support data deploy".format(ctx.system.name)
        )

    with ExitStack() as context_stack:
        for manager in get_context_managers("run", ctx):
            context_stack.enter_context(manager(ctx))

        print("Generating commands to execute")
        run_commands = build_run_commands(ctx)

        if ctx.system.connectable:
            establish_connection(ctx)

        if ctx.system.supports_deploy:
            deploy_data(ctx)

        reporter = ctx.reporter
        for command in run_commands:
            print("Running: {}".format(command))
            exit_code, stdout_data, stderr_data = ctx.system.run(command)

            if exit_code != 0:
                print("Application exited with exit code {}".format(exit_code))

            if reporter:
                # Parse metrics first, then hide the parsed payload from the
                # echoed output.
                reporter.parse(stdout_data)
                stdout_data = reporter.get_filtered_output(stdout_data)

            print(stdout_data.decode("utf8"), end="")
            print(stderr_data.decode("utf8"), end="")

        if reporter:
            reporter.save(reporter.report(ctx), cast(Path, ctx.report_file))
def establish_connection(
    ctx: ExecutionContext, retries: int = 90, interval: float = 15.0
) -> None:
    """Poll the system until a connection is established or retries run out."""
    assert ctx.system is not None, "System is required."
    host, port = ctx.system.connection_details()
    print(
        "Trying to establish connection with '{}:{}' - "
        "{} retries every {} seconds ".format(host, port, retries, interval),
        end="",
    )

    try:
        for _ in range(retries):
            print(".", end="", flush=True)

            if ctx.system.establish_connection():
                break

            # A controlled system that already died will never connect:
            # surface its output and fail fast.
            if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running():
                print(
                    "\n\n---------- {} execution failed ----------".format(
                        ctx.system.name
                    )
                )
                stdout, stderr = ctx.system.get_output()
                print(stdout)
                print(stderr)

                raise Exception("System is not running")

            wait(interval)
        else:
            raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port))
    finally:
        # Terminate the dotted progress line in every case.
        print()


def wait(interval: float) -> None:
    """Sleep for the given number of seconds."""
    time.sleep(interval)


def deploy_data(ctx: ExecutionContext) -> None:
    """Copy the application's deploy data onto the system."""
    if isinstance(ctx.app, Application):
        # Only application can deploy data (tools can not)
        assert ctx.system is not None, "System is required."
        for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data):
            print("Deploying {} onto {}".format(item.src, item.dst))
            ctx.system.deploy(item.src, item.dst)
def build_run_commands(ctx: ExecutionContext) -> List[str]:
    """Build the command lines used to run the application."""
    # Standalone systems drive the run themselves; otherwise the
    # application's own 'run' command is used.
    if isinstance(ctx.system, StandaloneSystem):
        return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver)
    return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver)


@contextmanager
def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Start the controlled system before the run and stop it afterwards."""
    system = cast(ControlledSystem, ctx.system)
    start_commands = system.build_command("run", ctx.system_params, ctx.param_resolver)

    pid_file_path: Optional[Path] = None
    if ctx.is_locking_required:
        # The pid file lives next to the lock file, sharing its stem.
        file_lock_path = get_file_lock_path(ctx)
        pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem)

    system.start(start_commands, ctx.is_locking_required, pid_file_path)
    try:
        yield
    finally:
        print("Shutting down sequence...")
        print("Stopping {}... (It could take few seconds)".format(system.name))
        system.stop(wait=True)
        print("{} stopped successfully.".format(system.name))


@contextmanager
def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Hold an exclusive file lock for the duration of the execution."""
    lock = FileLock(str(get_file_lock_path(ctx)))
    try:
        lock.acquire(timeout=1)
    except Timeout as error:
        raise AnotherInstanceIsRunningException() from error

    try:
        yield
    finally:
        lock.release()


def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path:
    """Compose the lock-file path for the current app/system combination."""
    # NOTE(review): a predictable path under /tmp is shared between users;
    # confirm that is acceptable for the deployment environment.
    modules = []
    if ctx.app.lock:
        modules.append(ctx.app.name)
    if ctx.system is not None and ctx.system.lock:
        modules.append(ctx.system.name)

    filename = ""
    if modules:
        filename = "_".join(["middleware"] + modules) + ".lock"
        filename = resolve_all_parameters(filename, ctx.param_resolver)
        filename = valid_for_filename(filename)

    if not filename:
        raise ConfigurationException("No filename for lock provided")

    if not isinstance(lock_dir, Path) or not lock_dir.is_dir():
        raise ConfigurationException(
            "Invalid directory {} for lock files provided".format(lock_dir)
        )

    return lock_dir / filename


@contextmanager
def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
    """Remove a unique build directory once the execution finishes."""
    try:
        yield
    finally:
        if (
            ctx.is_build_required
            and ctx.is_unique_build_dir_required
            and ctx.build_dir().is_dir()
        ):
            remove_directory(ctx.build_dir())


def get_context_managers(
    command_name: str, ctx: ExecutionContext
) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]:
    """Select the context managers needed for the given command."""
    managers: List[Callable[[ExecutionContext], ContextManager[None]]] = []

    if ctx.is_locking_required:
        managers.append(lock_execution_manager)

    if command_name == "run" and isinstance(ctx.system, ControlledSystem):
        managers.append(controlled_system_manager)

    return managers
def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool:
    """Return a tool, optionally narrowed down by the provided system name.

    Raises ConfigurationException when the tool does not exist (or does not
    support the given system), or when the name is ambiguous because no
    system was specified.
    """
    tools = get_tool(tool_name, system_name)
    if not tools:
        raise ConfigurationException(
            "Tool '{}' not found or doesn't support the system '{}'".format(
                tool_name, system_name
            )
        )
    if len(tools) != 1:
        raise ConfigurationException(
            "Please specify the system for tool {}.".format(tool_name)
        )
    return tools[0]


def execute_tool_command(
    tool_name: str,
    tool_params: List[str],
    system_name: Optional[str] = None,
) -> None:
    """Execute the tool command locally calling the 'run' command."""
    tool = get_tool_by_system(tool_name, system_name)
    # Tools run without a system; the context only supplies param resolution.
    ctx = ExecutionContext(
        app=tool, app_params=tool_params, system=None, system_params=[]
    )
    commands = tool.build_command("run", tool_params, ctx.param_resolver)

    execute_commands_locally(commands, Path.cwd())
class OutputParser(ABC):
    """Abstract base class for output parsers."""

    def __init__(self, name: str) -> None:
        """Set up the name of the parser."""
        super().__init__()
        self.name = name

    @abstractmethod
    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """Parse the output and return a map of names to metrics."""
        return {}

    # pylint: disable=no-self-use
    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """Filter out the parsed content from the output.

        Identity by default; subclasses may override.
        """
        return output


class RegexOutputParser(OutputParser):
    """Parse standard output data using regular expressions.

    The configuration maps a metric name to a dict with keys 'pattern' and
    'type':
    - 'pattern' is a regular expression containing exactly one capturing
      parenthesis
    - 'type' is one of ['str', 'float', 'int']

    Example:
    ```
    {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}}
    ```

    All patterns are combined with non-capturing alternation, so they must
    not overlap if more than one match per line is expected.
    """

    _TYPE_MAP = {"str": str, "float": float, "int": int}

    def __init__(
        self,
        name: str,
        regex_config: Dict[str, Dict[str, str]],
    ) -> None:
        """Validate the config and pre-compile the combined expression."""
        super().__init__(name)

        self._verify_config(regex_config)
        self._regex_cfg = regex_config

        combined = "|".join(
            "(?:{0})".format(cfg["pattern"]) for cfg in self._regex_cfg.values()
        )
        self._regex = re.compile(combined)

    def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]:
        """Parse the output and return a map of names to metrics.

        Each alternation branch corresponds to one configured metric; the
        branch that matched determines which group of the findall() result
        is populated.
        """
        parsed: Dict[str, Union[str, float, int]] = {}
        for match in self._regex.findall(output.decode("utf-8")):
            # findall() yields a plain string for a single group, a tuple
            # otherwise — normalise to a tuple.
            groups = (match,) if isinstance(match, str) else match
            for raw_value, (metric_name, cfg) in zip(groups, self._regex_cfg.items()):
                if raw_value:
                    parsed[metric_name] = self._TYPE_MAP[cfg["type"]](raw_value)
        return parsed

    def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None:
        """Check each pattern has one capturing group and a supported type."""
        for name, cfg in regex_config.items():
            compiled = re.compile(cfg["pattern"])
            if compiled.groups != 1:
                raise ValueError(
                    f"Pattern for metric '{name}' must have exactly one "
                    f"capturing parenthesis, but it has {compiled.groups}."
                )
            if cfg["type"] not in self._TYPE_MAP:
                raise TypeError(
                    f"Type '{cfg['type']}' for metric '{name}' is not "
                    f"supported. Choose from: {list(self._TYPE_MAP.keys())}."
                )
class Base64OutputParser(OutputParser):
    """
    Parser to extract base64-encoded JSON from tagged standard output.

    Example of the tagged output:
    ```
    # Encoded JSON: {"test": 1234}
    <metrics>eyJ0ZXN0IjogMTIzNH0</metrics>
    ```
    """

    TAG_NAME = "metrics"

    def __init__(self, name: str) -> None:
        """Set up the regular expression to extract tagged strings."""
        super().__init__(name)
        # Non-greedy capture: with a greedy "(.*)", two tagged payloads on
        # the same line would be merged into a single invalid base64 blob
        # (and filter_out_parsed_content() would drop the text between
        # them). For a single tag per line the behaviour is unchanged.
        self._regex = re.compile(rf"<{self.TAG_NAME}>(.*?)</{self.TAG_NAME}>")

    def __call__(self, output: bytearray) -> Dict[str, Any]:
        """
        Parse the output and return a map of index (as string) to decoded JSON.

        Example:
            Using the tagged output from the class docs the parser should
            return the following dict:
            ```
            {
                "0": {"test": 1234}
            }
            ```
        """
        metrics = {}
        output_str = output.decode("utf-8")
        results = self._regex.findall(output_str)
        for idx, result_base64 in enumerate(results):
            # validate=True makes non-base64 payloads fail loudly instead of
            # being silently skipped over.
            result_json = base64.b64decode(result_base64, validate=True)
            metrics[str(idx)] = json.loads(result_json)

        return metrics

    def filter_out_parsed_content(self, output: bytearray) -> bytearray:
        """Filter out base64-encoded content from the output."""
        output_str = self._regex.sub("", output.decode("utf-8"))
        return bytearray(output_str.encode("utf-8"))
class SSHConnectionException(Exception):
    """SSH connection exception."""


class SupportsClose(ABC):
    """Interface marker: the protocol session can be closed."""

    @abstractmethod
    def close(self) -> None:
        """Close protocol session."""


class SupportsDeploy(ABC):
    """Interface marker: data can be deployed through the protocol."""

    @abstractmethod
    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Deploy data from src to dst."""


class SupportsConnection(ABC):
    """Interface marker: the protocol uses network connections."""

    @abstractmethod
    def establish_connection(self) -> bool:
        """Establish connection with underlying system."""

    @abstractmethod
    def connection_details(self) -> Tuple[str, int]:
        """Return connection details (host, port)."""


class Protocol(ABC):
    """Base class for communication protocols.

    Configuration entries are copied straight into the instance dict and
    then checked by the subclass-specific _validate() hook.
    """

    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
        """Initialize instance attributes from a dict-like configuration."""
        self.__dict__.update(iterable, **kwargs)
        self._validate()

    @abstractmethod
    def _validate(self) -> None:
        """Validate the configuration data."""

    @abstractmethod
    def run(
        self, command: str, retry: bool = False
    ) -> Tuple[int, bytearray, bytearray]:
        """Run a command through the protocol.

        Returns a tuple: (exit_code, stdout, stderr)
        """
Tuple[int, bytearray, bytearray]: - """ - Abstract method for running commands. - - Returns a tuple: (exit_code, stdout, stderr) - """ - - -class CustomSFTPClient(paramiko.SFTPClient): - """Class for creating a custom sftp client.""" - - def put_dir(self, source: Path, target: str) -> None: - """Upload the source directory to the target path. - - The target directory needs to exists and the last directory of the - source will be created under the target with all its content. - """ - # Create the target directory - self._mkdir(target, ignore_existing=True) - # Create the last directory in the source on the target - self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) - # Go through the whole content of source - for item in sorted(source.glob("**/*")): - relative_path = item.relative_to(source.parent) - remote_target = target / relative_path - if item.is_file(): - self.put(str(item), str(remote_target)) - else: - self._mkdir(str(remote_target), ignore_existing=True) - - def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: - """Extend mkdir functionality. - - This version adds an option to not fail if the folder exists. - """ - try: - super().mkdir(path, mode) - except IOError as error: - if ignore_existing: - pass - else: - raise error - - -class LocalProtocol(Protocol): - """Class for local protocol.""" - - protocol: str - cwd: Path - - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Run command locally. 
- - Returns a tuple: (exit_code, stdout, stderr) - """ - if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(self.cwd)) - - stdout = bytearray() - stderr = bytearray() - - return run_and_wait( - command, self.cwd, terminate_on_error=True, out=stdout, err=stderr - ) - - def _validate(self) -> None: - """Validate protocol configuration.""" - assert hasattr(self, "protocol") and self.protocol == "local" - assert hasattr(self, "cwd") - - -class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): - """Class for SSH protocol.""" - - protocol: str - username: str - password: str - hostname: str - port: int - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - super().__init__(iterable, **kwargs) - # Internal state to store if the system is connectable. It will be set - # to true at the first connection instance - self.client: Optional[paramiko.client.SSHClient] = None - self.port = int(self.port) - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command over SSH. - - Returns a tuple: (exit_code, stdout, stderr) - """ - transport = self._get_transport() - with closing(transport.open_session()) as channel: - # Enable shell's .profile settings and execute command - channel.exec_command("bash -l -c '{}'".format(command)) - exit_status = -1 - stdout = bytearray() - stderr = bytearray() - while True: - if channel.exit_status_ready(): - exit_status = channel.recv_exit_status() - # Call it one last time to read any leftover in the channel - self._recv_stdout_err(channel, stdout, stderr) - break - self._recv_stdout_err(channel, stdout, stderr) - - return exit_status, stdout, stderr - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy src to remote dst over SSH. - - src and dst should be path to a file or directory. 
- """ - transport = self._get_transport() - sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) - - with closing(sftp): - if src.is_dir(): - sftp.put_dir(src, dst) - elif src.is_file(): - sftp.put(str(src), dst) - else: - raise Exception("Deploy error: file type not supported") - - # After the deployment of files, sync the remote filesystem to flush - # buffers to hard disk - self.run("sync") - - def close(self) -> None: - """Close protocol session.""" - if self.client is not None: - print("Try syncing remote file system...") - # Before stopping the system, we try to run sync to make sure all - # data are flushed on disk. - self.run("sync", retry=False) - self._close_client(self.client) - - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - if self.client is not None: - return True - - self.client = self._connect() - return self.client is not None - - def _get_transport(self) -> paramiko.transport.Transport: - """Get transport.""" - self.establish_connection() - - if self.client is None: - raise SSHConnectionException( - "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) - ) - - transport = self.client.get_transport() - if not transport: - raise Exception("Unable to get transport") - - return transport - - def connection_details(self) -> Tuple[str, int]: - """Return connection details of underlying system.""" - return (self.hostname, self.port) - - def _connect(self) -> Optional[paramiko.client.SSHClient]: - """Try to establish connection.""" - client: Optional[paramiko.client.SSHClient] = None - try: - client = paramiko.client.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - self.hostname, - self.port, - self.username, - self.password, - # next parameters should be set to False to disable authentication - # using ssh keys - allow_agent=False, - look_for_keys=False, - ) - return client - except ( - # OSError raised on first attempt to connect 
when running inside Docker - OSError, - paramiko.ssh_exception.NoValidConnectionsError, - paramiko.ssh_exception.SSHException, - ): - # even if connection is not established socket could be still - # open, it should be closed - self._close_client(client) - - return None - - @staticmethod - def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: - """Close ssh client.""" - try: - if client is not None: - client.close() - except Exception: # pylint: disable=broad-except - pass - - @classmethod - def _recv_stdout_err( - cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray - ) -> None: - """Read from channel to stdout/stder.""" - chunk_size = 512 - if channel.recv_ready(): - stdout_chunk = channel.recv(chunk_size) - stdout.extend(stdout_chunk) - if channel.recv_stderr_ready(): - stderr_chunk = channel.recv_stderr(chunk_size) - stderr.extend(stderr_chunk) - - def _validate(self) -> None: - """Check if there are all the info for establishing the connection.""" - assert hasattr(self, "protocol") and self.protocol == "ssh" - assert hasattr(self, "username") - assert hasattr(self, "password") - assert hasattr(self, "hostname") - assert hasattr(self, "port") - - -class ProtocolFactory: - """Factory class to return the appropriate Protocol class.""" - - @staticmethod - def get_protocol( - config: Optional[Union[SSHConfig, LocalProtocolConfig]], - **kwargs: Union[str, Path, None] - ) -> Union[SSHProtocol, LocalProtocol]: - """Return the right protocol instance based on the config.""" - if not config: - raise ValueError("No protocol config provided") - - protocol = config["protocol"] - if protocol == "ssh": - return SSHProtocol(config) - - if protocol == "local": - cwd = kwargs.get("cwd") - return LocalProtocol(config, cwd=cwd) - - raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py deleted file mode 100644 index dec175a..0000000 --- 
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain source related classes and functions."""
import os
import shutil
import tarfile
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from tarfile import TarFile
from typing import Optional
from typing import Union

from aiet.backend.common import AIET_CONFIG_FILE
from aiet.backend.common import ConfigurationException
from aiet.backend.common import get_backend_config
from aiet.backend.common import is_backend_directory
from aiet.backend.common import load_config
from aiet.backend.config import BackendConfig
from aiet.utils.fs import copy_directory_content


class Source(ABC):
    """Source class."""

    @abstractmethod
    def name(self) -> Optional[str]:
        """Get source name."""

    @abstractmethod
    def config(self) -> Optional[BackendConfig]:
        """Get configuration file content."""

    @abstractmethod
    def install_into(self, destination: Path) -> None:
        """Install source into destination directory."""

    @abstractmethod
    def create_destination(self) -> bool:
        """Return True if destination folder should be created before installation."""


class DirectorySource(Source):
    """DirectorySource class."""

    def __init__(self, directory_path: Path) -> None:
        """Create the DirectorySource instance."""
        assert isinstance(directory_path, Path)
        self.directory_path = directory_path

    def name(self) -> str:
        """Return name of source."""
        return self.directory_path.name

    def config(self) -> Optional[BackendConfig]:
        """Return configuration file content.

        Raises ConfigurationException if the directory does not contain a
        backend configuration file.
        """
        if not is_backend_directory(self.directory_path):
            raise ConfigurationException("No configuration file found")

        config_file = get_backend_config(self.directory_path)
        return load_config(config_file)

    def install_into(self, destination: Path) -> None:
        """Install source into destination directory."""
        if not destination.is_dir():
            raise ConfigurationException("Wrong destination {}".format(destination))

        if not self.directory_path.is_dir():
            raise ConfigurationException(
                "Directory {} does not exist".format(self.directory_path)
            )

        copy_directory_content(self.directory_path, destination)

    def create_destination(self) -> bool:
        """Return True if destination folder should be created before installation."""
        return True


class TarArchiveSource(Source):
    """TarArchiveSource class."""

    def __init__(self, archive_path: Path) -> None:
        """Create the TarArchiveSource class."""
        assert isinstance(archive_path, Path)
        self.archive_path = archive_path
        # Lazily populated by _read_archive_content():
        self._config: Optional[BackendConfig] = None
        self._has_top_level_folder: Optional[bool] = None
        self._name: Optional[str] = None

    def _read_archive_content(self) -> None:
        """Read various information about archive."""
        # Get the source name from the archive name (everything without
        # extensions). str.rstrip() must NOT be used here: it strips a *set*
        # of characters rather than a suffix, so e.g.
        # "extra.tar.gz".rstrip(".tar.gz") would also eat the trailing
        # 't'/'r'/'a' letters of the stem and yield "ex". Remove the suffix
        # by its length instead.
        extensions = "".join(self.archive_path.suffixes)
        self._name = (
            self.archive_path.name[: -len(extensions)]
            if extensions
            else self.archive_path.name
        )

        if not self.archive_path.exists():
            return

        with self._open(self.archive_path) as archive:
            try:
                # Case 1: config file sits at the archive root.
                config_entry = archive.getmember(AIET_CONFIG_FILE)
                self._has_top_level_folder = False
            except KeyError as error_no_config:
                try:
                    # Case 2: all entries share a single top-level directory
                    # which contains the config file; that directory also
                    # provides the source name.
                    archive_entries = archive.getnames()
                    entries_common_prefix = os.path.commonprefix(archive_entries)
                    top_level_dir = entries_common_prefix.rstrip("/")

                    if not top_level_dir:
                        raise RuntimeError(
                            "Archive has no top level directory"
                        ) from error_no_config

                    config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)

                    config_entry = archive.getmember(config_path)
                    self._has_top_level_folder = True
                    self._name = top_level_dir
                except (KeyError, RuntimeError) as error_no_root_dir_or_config:
                    raise ConfigurationException(
                        "No configuration file found"
                    ) from error_no_root_dir_or_config

            content = archive.extractfile(config_entry)
            self._config = load_config(content)

    def config(self) -> Optional[BackendConfig]:
        """Return configuration file content."""
        if self._config is None:
            self._read_archive_content()

        return self._config

    def name(self) -> Optional[str]:
        """Return name of the source."""
        if self._name is None:
            self._read_archive_content()

        return self._name

    def create_destination(self) -> bool:
        """Return True if destination folder must be created before installation."""
        if self._has_top_level_folder is None:
            self._read_archive_content()

        return not self._has_top_level_folder

    def install_into(self, destination: Path) -> None:
        """Install source into destination directory."""
        if not destination.is_dir():
            raise ConfigurationException("Wrong destination {}".format(destination))

        with self._open(self.archive_path) as archive:
            # NOTE(review): extractall() does not sanitise member paths, so a
            # crafted archive could write outside `destination`. Only install
            # archives from trusted sources.
            archive.extractall(destination)

    def _open(self, archive_path: Path) -> TarFile:
        """Open archive file.

        Only gzip-compressed tarballs (.tar.gz / .tgz) are supported.
        """
        if not archive_path.is_file():
            raise ConfigurationException("File {} does not exist".format(archive_path))

        if archive_path.name.endswith(("tar.gz", "tgz")):
            mode = "r:gz"
        else:
            raise ConfigurationException(
                "Unsupported archive type {}".format(archive_path)
            )

        # The returned TarFile object can be used as a context manager (using
        # 'with') by the calling instance.
        return tarfile.open(  # pylint: disable=consider-using-with
            self.archive_path, mode=mode
        )


def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
    """Return appropriate source instance based on provided source path."""
    if source_path.is_file():
        return TarArchiveSource(source_path)

    if source_path.is_dir():
        return DirectorySource(source_path)

    raise ConfigurationException("Unable to read {}".format(source_path))


def create_destination_and_install(source: Source, resource_path: Path) -> None:
    """Create destination directory and install source.

    This function is used for actual installation of system/backend New
    directory will be created inside :resource_path: if needed If for example
    archive contains top level folder then no need to create new directory
    """
    destination = resource_path
    create_destination = source.create_destination()

    if create_destination:
        name = source.name()
        if not name:
            raise ConfigurationException("Unable to get source name")

        destination = resource_path / name
        destination.mkdir()
    try:
        source.install_into(destination)
    except Exception as error:
        # Roll back the directory we created so a failed install does not
        # leave a half-populated destination behind.
        if create_destination:
            shutil.rmtree(destination)
        raise error
# SPDX-License-Identifier: Apache-2.0
"""System backend module."""
from pathlib import Path
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from aiet.backend.common import Backend
from aiet.backend.common import ConfigurationException
from aiet.backend.common import get_backend_configs
from aiet.backend.common import get_backend_directories
from aiet.backend.common import load_config
from aiet.backend.common import remove_backend
from aiet.backend.config import SystemConfig
from aiet.backend.controller import SystemController
from aiet.backend.controller import SystemControllerSingleInstance
from aiet.backend.protocol import ProtocolFactory
from aiet.backend.protocol import SupportsClose
from aiet.backend.protocol import SupportsConnection
from aiet.backend.protocol import SupportsDeploy
from aiet.backend.source import create_destination_and_install
from aiet.backend.source import get_source
from aiet.utils.fs import get_resources


def get_available_systems_directory_names() -> List[str]:
    """Return a list of directory names for all available systems."""
    return [entry.name for entry in get_backend_directories("systems")]


def get_available_systems() -> List["System"]:
    """Return a list with all available systems, sorted by name."""
    available_systems = []
    for config_json in get_backend_configs("systems"):
        config_entries = cast(List[SystemConfig], (load_config(config_json)))
        for config_entry in config_entries:
            # Remember where the config came from so relative resources
            # (e.g. local protocol cwd) can be resolved later.
            config_entry["config_location"] = config_json.parent.absolute()
            system = load_system(config_entry)
            available_systems.append(system)

    return sorted(available_systems, key=lambda system: system.name)


def get_system(system_name: str) -> Optional["System"]:
    """Return a system instance with the same name passed as argument."""
    available_systems = get_available_systems()
    for system in available_systems:
        if system_name == system.name:
            return system
    return None


def install_system(source_path: Path) -> None:
    """Install new system.

    Raises ConfigurationException if the source cannot be read, defines no
    system, or defines a system that is already installed.
    """
    try:
        source = get_source(source_path)
        config = cast(List[SystemConfig], source.config())
        systems_to_install = [load_system(entry) for entry in config]
    except Exception as error:
        raise ConfigurationException("Unable to read system definition") from error

    if not systems_to_install:
        raise ConfigurationException("No system definition found")

    available_systems = get_available_systems()
    already_installed = [s for s in systems_to_install if s in available_systems]
    if already_installed:
        names = [system.name for system in already_installed]
        raise ConfigurationException(
            "Systems [{}] are already installed".format(",".join(names))
        )

    create_destination_and_install(source, get_resources("systems"))


def remove_system(directory_name: str) -> None:
    """Remove system."""
    remove_backend(directory_name, "systems")


class System(Backend):
    """System class."""

    def __init__(self, config: SystemConfig) -> None:
        """Construct the System class using the dictionary passed."""
        super().__init__(config)

        self._setup_data_transfer(config)
        self._setup_reporting(config)

    def _setup_data_transfer(self, config: SystemConfig) -> None:
        # The protocol instance (local or ssh) handles all command execution
        # and file transfer for this system.
        data_transfer_config = config.get("data_transfer")
        protocol = ProtocolFactory().get_protocol(
            data_transfer_config, cwd=self.config_location
        )
        self.protocol = protocol

    def _setup_reporting(self, config: SystemConfig) -> None:
        self.reporting = config.get("reporting")

    def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
        """
        Run command on the system.

        Returns a tuple: (exit_code, stdout, stderr)
        """
        return self.protocol.run(command, retry)

    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Deploy files to the system.

        Silently does nothing if the underlying protocol has no deploy
        support (check supports_deploy first if that matters).
        """
        if isinstance(self.protocol, SupportsDeploy):
            self.protocol.deploy(src, dst, retry)

    @property
    def supports_deploy(self) -> bool:
        """Check if protocol supports deploy operation."""
        return isinstance(self.protocol, SupportsDeploy)

    @property
    def connectable(self) -> bool:
        """Check if protocol supports connection."""
        return isinstance(self.protocol, SupportsConnection)

    def establish_connection(self) -> bool:
        """Establish connection with the system."""
        if not isinstance(self.protocol, SupportsConnection):
            raise ConfigurationException(
                "System {} does not support connections".format(self.name)
            )

        return self.protocol.establish_connection()

    def connection_details(self) -> Tuple[str, int]:
        """Return connection details."""
        if not isinstance(self.protocol, SupportsConnection):
            raise ConfigurationException(
                "System {} does not support connections".format(self.name)
            )

        return self.protocol.connection_details()

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        if not isinstance(other, System):
            return False

        return super().__eq__(other) and self.name == other.name

    def get_details(self) -> Dict[str, Any]:
        """Return a dictionary with all relevant information of a System."""
        output = {
            "type": "system",
            "name": self.name,
            "description": self.description,
            "data_transfer_protocol": self.protocol.protocol,
            "commands": self._get_command_details(),
            "annotations": self.annotations,
        }

        return output


class StandaloneSystem(System):
    """StandaloneSystem class."""


def get_controller(
    single_instance: bool, pid_file_path: Optional[Path] = None
) -> SystemController:
    """Get system controller."""
    if single_instance:
        return SystemControllerSingleInstance(pid_file_path)

    return SystemController()


class ControlledSystem(System):
    """ControlledSystem class.

    A system whose lifecycle (start/stop) is managed via a SystemController;
    created by load_system() for the 'ssh' data transfer protocol.
    """

    def __init__(self, config: SystemConfig):
        """Construct the ControlledSystem class using the dictionary passed."""
        super().__init__(config)
        # Set by start(); None until the system has been launched.
        self.controller: Optional[SystemController] = None

    def start(
        self,
        commands: List[str],
        single_instance: bool = True,
        pid_file_path: Optional[Path] = None,
    ) -> None:
        """Launch the system."""
        if (
            not isinstance(self.config_location, Path)
            or not self.config_location.is_dir()
        ):
            raise ConfigurationException(
                "System {} has wrong config location".format(self.name)
            )

        self.controller = get_controller(single_instance, pid_file_path)
        self.controller.start(commands, self.config_location)

    def is_running(self) -> bool:
        """Check if system is running."""
        if not self.controller:
            return False

        return self.controller.is_running()

    def get_output(self) -> Tuple[str, str]:
        """Return system output."""
        if not self.controller:
            return "", ""

        return self.controller.get_output()

    def stop(self, wait: bool = False) -> None:
        """Stop the system."""
        if not self.controller:
            raise Exception("System has not been started")

        # Close the protocol session first (best effort) so remote buffers
        # are flushed before the controller kills the system process.
        if isinstance(self.protocol, SupportsClose):
            try:
                self.protocol.close()
            except Exception as error:  # pylint: disable=broad-except
                print(error)
        self.controller.stop(wait)


def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]:
    """Load system based on its execution type."""
    data_transfer = config.get("data_transfer", {})
    protocol = data_transfer.get("protocol")
    populate_shared_params(config)

    if protocol == "ssh":
        return ControlledSystem(config)

    if protocol == "local":
        return StandaloneSystem(config)

    raise ConfigurationException(
        "Unsupported execution type for protocol {}".format(protocol)
    )


def populate_shared_params(config: SystemConfig) -> None:
    """Populate command parameters with shared parameters.

    Merges user_params["shared"] into the "build" and "run" parameter lists
    (command-specific entries override shared ones by alias) and removes the
    "shared" key afterwards. Mutates config in place.
    """
    user_params = config.get("user_params")
    if not user_params or "shared" not in user_params:
        return

    shared_user_params = user_params["shared"]
    if not shared_user_params:
        return

    # Aliases are the merge key, so every shared parameter must have one.
    only_aliases = all(p.get("alias") for p in shared_user_params)
    if not only_aliases:
        raise ConfigurationException("All shared parameters should have aliases")

    commands = config.get("commands", {})
    for cmd_name in ["build", "run"]:
        command = commands.get(cmd_name)
        if command is None:
            commands[cmd_name] = []
        cmd_user_params = user_params.get(cmd_name)
        if not cmd_user_params:
            # No command-specific params: the command inherits the shared ones.
            cmd_user_params = shared_user_params
        else:
            only_aliases = all(p.get("alias") for p in cmd_user_params)
            if not only_aliases:
                raise ConfigurationException(
                    "All parameters for command {} should have aliases".format(cmd_name)
                )
            # Later (command-specific) entries win over shared entries with
            # the same alias.
            merged_by_alias = {
                **{p.get("alias"): p for p in shared_user_params},
                **{p.get("alias"): p for p in cmd_user_params},
            }
            cmd_user_params = list(merged_by_alias.values())

        user_params[cmd_name] = cmd_user_params

    config["commands"] = commands
    del user_params["shared"]
# SPDX-License-Identifier: Apache-2.0
"""Tool backend module."""
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional

from aiet.backend.common import Backend
from aiet.backend.common import ConfigurationException
from aiet.backend.common import get_backend_configs
from aiet.backend.common import get_backend_directories
from aiet.backend.common import load_application_or_tool_configs
from aiet.backend.common import load_config
from aiet.backend.config import ExtendedToolConfig
from aiet.backend.config import ToolConfig


def get_available_tool_directory_names() -> List[str]:
    """Return a list of directory names for all available tools."""
    return [directory.name for directory in get_backend_directories("tools")]


def get_available_tools() -> List["Tool"]:
    """Return a list with all available tools, sorted by name."""
    collected: List["Tool"] = []
    for config_file in get_backend_configs("tools"):
        entries = cast(List[ExtendedToolConfig], load_config(config_file))
        location = config_file.parent.absolute()
        for entry in entries:
            # Remember where the config came from before instantiating.
            entry["config_location"] = location
            collected.extend(load_tools(entry))

    collected.sort(key=lambda item: item.name)
    return collected


def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
    """Return a tool instance with the same name passed as argument."""
    matches = []
    for candidate in get_available_tools():
        if candidate.name != tool_name:
            continue
        if system_name and not candidate.can_run_on(system_name):
            continue
        matches.append(candidate)
    return matches


def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
    """Extract a list of unique tool names of all tools available."""
    names = {
        candidate.name
        for candidate in get_available_tools()
        if not system_name or candidate.can_run_on(system_name)
    }
    return list(names)


class Tool(Backend):
    """Class for representing a single tool component."""

    def __init__(self, config: ToolConfig) -> None:
        """Construct a Tool instance from a dict."""
        super().__init__(config)

        self.supported_systems = config.get("supported_systems", [])

        # Every tool must be executable: a 'run' command is mandatory.
        if "run" not in self.commands:
            raise ConfigurationException("A Tool must have a 'run' command.")

    def __eq__(self, other: object) -> bool:
        """Overload operator ==."""
        if not isinstance(other, Tool):
            return False
        if not super().__eq__(other):
            return False
        if self.name != other.name:
            return False
        return set(self.supported_systems) == set(other.supported_systems)

    def can_run_on(self, system_name: str) -> bool:
        """Check if the tool can run on the system passed as argument."""
        return system_name in self.supported_systems

    def get_details(self) -> Dict[str, Any]:
        """Return dictionary with all relevant information of the Tool instance."""
        return {
            "type": "tool",
            "name": self.name,
            "description": self.description,
            "supported_systems": self.supported_systems,
            "commands": self._get_command_details(),
        }


def load_tools(config: ExtendedToolConfig) -> List[Tool]:
    """Load tool.

    Tool configuration could contain different parameters/commands for different
    supported systems. For each supported system this function will return separate
    Tool instance with appropriate configuration.
    """
    return [
        Tool(cfg)
        for cfg in load_application_or_tool_configs(
            config, ToolConfig, is_system_required=False
        )
    ]
-# SPDX-License-Identifier: Apache-2.0 -"""Module to mange the CLI interface.""" -import click - -from aiet import __version__ -from aiet.cli.application import application_cmd -from aiet.cli.completion import completion_cmd -from aiet.cli.system import system_cmd -from aiet.cli.tool import tool_cmd -from aiet.utils.helpers import set_verbosity - - -@click.group() -@click.version_option(__version__) -@click.option( - "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False -) -@click.pass_context -def cli(ctx: click.Context) -> None: # pylint: disable=unused-argument - """AIET: AI Evaluation Toolkit.""" - # Unused arguments must be present here in definition to pass click context. - - -cli.add_command(application_cmd) -cli.add_command(system_cmd) -cli.add_command(tool_cmd) -cli.add_command(completion_cmd) diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py deleted file mode 100644 index 59b652d..0000000 --- a/src/aiet/cli/application.py +++ /dev/null @@ -1,362 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause -"""Module to manage the CLI interface of applications.""" -import json -import logging -import re -from pathlib import Path -from typing import Any -from typing import IO -from typing import List -from typing import Optional -from typing import Tuple - -import click -import cloup - -from aiet.backend.application import get_application -from aiet.backend.application import get_available_application_directory_names -from aiet.backend.application import get_unique_application_names -from aiet.backend.application import install_application -from aiet.backend.application import remove_application -from aiet.backend.common import DataPaths -from aiet.backend.execution import execute_application_command -from aiet.backend.execution import run_application -from aiet.backend.system import get_available_systems -from aiet.cli.common import get_format -from aiet.cli.common import middleware_exception_handler -from aiet.cli.common import middleware_signal_handler -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="application") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def application_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage applications.""" - set_format(ctx, format_) - - -@application_cmd.command(name="list") -@click.pass_context -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=False, -) -def list_cmd(ctx: click.Context, system_name: str) -> None: - """List all available applications.""" - unique_application_names = get_unique_application_names(system_name) - unique_application_names.sort() - if get_format(ctx) == "json": - data = {"type": "application", "available": unique_application_names} - print(json.dumps(data)) - else: - 
print("Available applications:\n") - print(*unique_application_names, sep="\n") - - -@application_cmd.command(name="details") -@click.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), - required=True, -) -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=False, -) -@click.pass_context -def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None: - """Details of a specific application.""" - applications = get_application(application_name, system_name) - if not applications: - raise click.UsageError( - "Application '{}' doesn't support the system '{}'".format( - application_name, system_name - ) - ) - - if get_format(ctx) == "json": - applications_details = [s.get_details() for s in applications] - print(json.dumps(applications_details)) - else: - for application in applications: - application_details = application.get_details() - application_details_template = ( - 'Application "{name}" details\nDescription: {description}' - ) - - print( - application_details_template.format( - name=application_details["name"], - description=application_details["description"], - ) - ) - - print( - "\nSupported systems: {}".format( - ", ".join(application_details["supported_systems"]) - ) - ) - - command_details = application_details["commands"] - - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -# pylint: disable=too-many-arguments -@application_cmd.command(name="execute") -@click.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), - required=True, -) -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=True, -) -@click.option( - "-c", - "--command", - "command_name", - type=click.Choice(["build", "run"]), - required=True, 
-) -@click.option("-p", "--param", "application_params", multiple=True) -@click.option("--system-param", "system_params", multiple=True) -@click.option("-d", "--deploy", "deploy_params", multiple=True) -@middleware_signal_handler -@middleware_exception_handler -def execute_cmd( - application_name: str, - system_name: str, - command_name: str, - application_params: List[str], - system_params: List[str], - deploy_params: List[str], -) -> None: - """Execute application commands. DEPRECATED! Use 'aiet application run' instead.""" - logging.warning( - "Please use 'aiet application run' instead. Use of 'aiet application " - "execute' is deprecated and might be removed in a future release." - ) - - custom_deploy_data = get_custom_deploy_data(command_name, deploy_params) - - execute_application_command( - command_name, - application_name, - application_params, - system_name, - system_params, - custom_deploy_data, - ) - - -@cloup.command(name="run") -@cloup.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), -) -@cloup.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), -) -@cloup.option("-p", "--param", "application_params", multiple=True) -@cloup.option("--system-param", "system_params", multiple=True) -@cloup.option("-d", "--deploy", "deploy_params", multiple=True) -@click.option( - "-r", - "--report", - "report_file", - type=Path, - help="Create a report file in JSON format containing metrics parsed from " - "the simulation output as specified in the aiet-config.json.", -) -@cloup.option( - "--config", - "config_file", - type=click.File("r"), - help="Read options from a config file rather than from the command line. 
" - "The config file is a json file.", -) -@cloup.constraint( - cloup.constraints.If( - cloup.constraints.conditions.Not( - cloup.constraints.conditions.IsSet("config_file") - ), - then=cloup.constraints.require_all, - ), - ["system_name", "application_name"], -) -@cloup.constraint( - cloup.constraints.If("config_file", then=cloup.constraints.accept_none), - [ - "system_name", - "application_name", - "application_params", - "system_params", - "deploy_params", - ], -) -@middleware_signal_handler -@middleware_exception_handler -def run_cmd( - application_name: str, - system_name: str, - application_params: List[str], - system_params: List[str], - deploy_params: List[str], - report_file: Optional[Path], - config_file: Optional[IO[str]], -) -> None: - """Execute application commands.""" - if config_file: - payload_data = json.load(config_file) - ( - system_name, - application_name, - application_params, - system_params, - deploy_params, - report_file, - ) = parse_payload_run_config(payload_data) - - custom_deploy_data = get_custom_deploy_data("run", deploy_params) - - run_application( - application_name, - application_params, - system_name, - system_params, - custom_deploy_data, - report_file, - ) - - -application_cmd.add_command(run_cmd) - - -def parse_payload_run_config( - payload_data: dict, -) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]: - """Parse the payload into a tuple.""" - system_id = payload_data.get("id") - arguments: Optional[Any] = payload_data.get("arguments") - - if not isinstance(system_id, str): - raise click.ClickException("invalid payload json: no system 'id'") - if not isinstance(arguments, dict): - raise click.ClickException("invalid payload json: no arguments object") - - application_name = arguments.pop("application", None) - if not isinstance(application_name, str): - raise click.ClickException("invalid payload json: no application_id") - - report_path = arguments.pop("report_path", None) - - application_params = [] - 
system_params = [] - deploy_params = [] - - for (param_key, value) in arguments.items(): - (par, _) = re.subn("^application/", "", param_key) - (par, found_sys_param) = re.subn("^system/", "", par) - (par, found_deploy_param) = re.subn("^deploy/", "", par) - - param_expr = par + "=" + value - if found_sys_param: - system_params.append(param_expr) - elif found_deploy_param: - deploy_params.append(par) - else: - application_params.append(param_expr) - - return ( - system_id, - application_name, - application_params, - system_params, - deploy_params, - report_path, - ) - - -def get_custom_deploy_data( - command_name: str, deploy_params: List[str] -) -> List[DataPaths]: - """Get custom deploy data information.""" - custom_deploy_data: List[DataPaths] = [] - if not deploy_params: - return custom_deploy_data - - for param in deploy_params: - parts = param.split(":") - if not len(parts) == 2 or any(not part.strip() for part in parts): - raise click.ClickException( - "Invalid deploy parameter '{}' for command {}".format( - param, command_name - ) - ) - data_path = DataPaths(Path(parts[0]), parts[1]) - if not data_path.src.exists(): - raise click.ClickException("Path {} does not exist".format(data_path.src)) - custom_deploy_data.append(data_path) - - return custom_deploy_data - - -@application_cmd.command(name="install") -@click.option( - "-s", - "--source", - "source", - required=True, - help="Path to the directory or archive with application definition", -) -def install_cmd(source: str) -> None: - """Install new application.""" - source_path = Path(source) - install_application(source_path) - - -@application_cmd.command(name="remove") -@click.option( - "-d", - "--directory_name", - "directory_name", - type=click.Choice(get_available_application_directory_names()), - required=True, - help="Name of the directory with application", -) -def remove_cmd(directory_name: str) -> None: - """Remove application.""" - remove_application(directory_name) diff --git 
a/src/aiet/cli/common.py b/src/aiet/cli/common.py deleted file mode 100644 index 1d157b6..0000000 --- a/src/aiet/cli/common.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Common functions for cli module.""" -import enum -import logging -from functools import wraps -from signal import SIG_IGN -from signal import SIGINT -from signal import signal as signal_handler -from signal import SIGTERM -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict - -from click import ClickException -from click import Context -from click import UsageError - -from aiet.backend.common import ConfigurationException -from aiet.backend.execution import AnotherInstanceIsRunningException -from aiet.backend.execution import ConnectionException -from aiet.backend.protocol import SSHConnectionException -from aiet.utils.proc import CommandFailedException - - -class MiddlewareExitCode(enum.IntEnum): - """Middleware exit codes.""" - - SUCCESS = 0 - # exit codes 1 and 2 are used by click - SHUTDOWN_REQUESTED = 3 - BACKEND_ERROR = 4 - CONCURRENT_ERROR = 5 - CONNECTION_ERROR = 6 - CONFIGURATION_ERROR = 7 - MODEL_OPTIMISED_ERROR = 8 - INVALID_TFLITE_FILE_ERROR = 9 - - -class CustomClickException(ClickException): - """Custom click exception.""" - - def show(self, file: Any = None) -> None: - """Override show method.""" - super().show(file) - - logging.debug("Execution failed with following exception: ", exc_info=self) - - -class MiddlewareShutdownException(CustomClickException): - """Exception indicates that user requested middleware shutdown.""" - - exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED) - - -class BackendException(CustomClickException): - """Exception indicates that command failed.""" - - exit_code = int(MiddlewareExitCode.BACKEND_ERROR) - - -class ConcurrentErrorException(CustomClickException): - """Exception indicates concurrent 
execution error.""" - - exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR) - - -class BackendConnectionException(CustomClickException): - """Exception indicates that connection could not be established.""" - - exit_code = int(MiddlewareExitCode.CONNECTION_ERROR) - - -class BackendConfigurationException(CustomClickException): - """Exception indicates some configuration issue.""" - - exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR) - - -class ModelOptimisedException(CustomClickException): - """Exception indicates input file has previously been Vela optimised.""" - - exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR) - - -class InvalidTFLiteFileError(CustomClickException): - """Exception indicates input TFLite file is misformatted.""" - - exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR) - - -def print_command_details(command: Dict) -> None: - """Print command details including parameters.""" - command_strings = command["command_strings"] - print("Commands: {}".format(command_strings)) - user_params = command["user_params"] - for i, param in enumerate(user_params, 1): - print("User parameter #{}".format(i)) - print("\tName: {}".format(param.get("name", "-"))) - print("\tDescription: {}".format(param["description"])) - print("\tPossible values: {}".format(param.get("values", "-"))) - print("\tDefault value: {}".format(param.get("default_value", "-"))) - print("\tAlias: {}".format(param.get("alias", "-"))) - - -def raise_exception_at_signal( - signum: int, frame: Any # pylint: disable=unused-argument -) -> None: - """Handle signals.""" - # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM - # signals will be ignored as we allow a graceful shutdown. 
- # Unused arguments must be present here in definition as used in signal handler - # callback - - signal_handler(SIGINT, SIG_IGN) - signal_handler(SIGTERM, SIG_IGN) - raise MiddlewareShutdownException("Middleware shutdown requested") - - -def middleware_exception_handler(func: Callable) -> Callable: - """Handle backend exceptions decorator.""" - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - return func(*args, **kwargs) - except (MiddlewareShutdownException, UsageError, ClickException) as error: - # click should take care of these exceptions - raise error - except ValueError as error: - raise ClickException(str(error)) from error - except AnotherInstanceIsRunningException as error: - raise ConcurrentErrorException( - "Another instance of the system is running" - ) from error - except (SSHConnectionException, ConnectionException) as error: - raise BackendConnectionException(str(error)) from error - except ConfigurationException as error: - raise BackendConfigurationException(str(error)) from error - except (CommandFailedException, Exception) as error: - raise BackendException( - "Execution failed. Please check output for the details." 
- ) from error - - return wrapper - - -def middleware_signal_handler(func: Callable) -> Callable: - """Handle signals decorator.""" - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command) - # The handler ignores further signals and it raises an exception - signal_handler(SIGINT, raise_exception_at_signal) - signal_handler(SIGTERM, raise_exception_at_signal) - - return func(*args, **kwargs) - - return wrapper - - -def set_format(ctx: Context, format_: str) -> None: - """Save format in click context.""" - ctx_obj = ctx.ensure_object(dict) - ctx_obj["format"] = format_ - - -def get_format(ctx: Context) -> str: - """Get format from click context.""" - ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict)) - return ctx_obj["format"] diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py deleted file mode 100644 index 71f054f..0000000 --- a/src/aiet/cli/completion.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -""" -Add auto completion to different shells with these helpers. - -See: https://click.palletsprojects.com/en/8.0.x/shell-completion/ -""" -import click - - -def _get_package_name() -> str: - return __name__.split(".", maxsplit=1)[0] - - -# aiet completion bash -@click.group(name="completion") -def completion_cmd() -> None: - """Enable auto completion for your shell.""" - - -@completion_cmd.command(name="bash") -def bash_cmd() -> None: - """ - Enable auto completion for bash. 
- - Use this command to activate completion in the current bash: - - eval "`aiet completion bash`" - - Use this command to add auto completion to bash globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion bash >> ~/.bashrc - """ - package_name = _get_package_name() - print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"') - - -@completion_cmd.command(name="zsh") -def zsh_cmd() -> None: - """ - Enable auto completion for zsh. - - Use this command to activate completion in the current zsh: - - eval "`aiet completion zsh`" - - Use this command to add auto completion to zsh globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion zsh >> ~/.zshrc - """ - package_name = _get_package_name() - print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"') - - -@completion_cmd.command(name="fish") -def fish_cmd() -> None: - """ - Enable auto completion for fish. - - Use this command to activate completion in the current fish: - - eval "`aiet completion fish`" - - Use this command to add auto completion to fish globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion fish >> ~/.config/fish/completions/aiet.fish - """ - package_name = _get_package_name() - print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"') diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py deleted file mode 100644 index f1f7637..0000000 --- a/src/aiet/cli/system.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Module to manage the CLI interface of systems.""" -import json -from pathlib import Path -from typing import cast - -import click - -from aiet.backend.application import get_available_applications -from aiet.backend.system import get_available_systems -from aiet.backend.system import get_available_systems_directory_names -from aiet.backend.system import get_system -from aiet.backend.system import install_system -from aiet.backend.system import remove_system -from aiet.backend.system import System -from aiet.cli.common import get_format -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="system") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def system_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage systems.""" - set_format(ctx, format_) - - -@system_cmd.command(name="list") -@click.pass_context -def list_cmd(ctx: click.Context) -> None: - """List all available systems.""" - available_systems = get_available_systems() - system_names = [system.name for system in available_systems] - if get_format(ctx) == "json": - data = {"type": "system", "available": system_names} - print(json.dumps(data)) - else: - print("Available systems:\n") - print(*system_names, sep="\n") - - -@system_cmd.command(name="details") -@click.option( - "-n", - "--name", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=True, -) -@click.pass_context -def details_cmd(ctx: click.Context, system_name: str) -> None: - """Details of a specific system.""" - system = cast(System, get_system(system_name)) - applications = [ - s.name for s in get_available_applications() if s.can_run_on(system.name) - ] - system_details = system.get_details() - if get_format(ctx) == "json": - system_details["available_application"] = applications - 
print(json.dumps(system_details)) - else: - system_details_template = ( - 'System "{name}" details\n' - "Description: {description}\n" - "Data Transfer Protocol: {protocol}\n" - "Available Applications: {available_application}" - ) - print( - system_details_template.format( - name=system_details["name"], - description=system_details["description"], - protocol=system_details["data_transfer_protocol"], - available_application=", ".join(applications), - ) - ) - - if system_details["annotations"]: - print("Annotations:") - for ann_name, ann_value in system_details["annotations"].items(): - print("\t{}: {}".format(ann_name, ann_value)) - - command_details = system_details["commands"] - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -@system_cmd.command(name="install") -@click.option( - "-s", - "--source", - "source", - required=True, - help="Path to the directory or archive with system definition", -) -def install_cmd(source: str) -> None: - """Install new system.""" - source_path = Path(source) - install_system(source_path) - - -@system_cmd.command(name="remove") -@click.option( - "-d", - "--directory_name", - "directory_name", - type=click.Choice(get_available_systems_directory_names()), - required=True, - help="Name of the directory with system", -) -def remove_cmd(directory_name: str) -> None: - """Remove system by given name.""" - remove_system(directory_name) diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py deleted file mode 100644 index 2c80821..0000000 --- a/src/aiet/cli/tool.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Module to manage the CLI interface of tools.""" -import json -from typing import Any -from typing import List -from typing import Optional - -import click - -from aiet.backend.execution import execute_tool_command -from aiet.backend.tool import get_tool -from aiet.backend.tool import get_unique_tool_names -from aiet.cli.common import get_format -from aiet.cli.common import middleware_exception_handler -from aiet.cli.common import middleware_signal_handler -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="tool") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def tool_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage tools.""" - set_format(ctx, format_) - - -@tool_cmd.command(name="list") -@click.pass_context -def list_cmd(ctx: click.Context) -> None: - """List all available tools.""" - # raise NotImplementedError("TODO") - tool_names = get_unique_tool_names() - tool_names.sort() - if get_format(ctx) == "json": - data = {"type": "tool", "available": tool_names} - print(json.dumps(data)) - else: - print("Available tools:\n") - print(*tool_names, sep="\n") - - -def validate_system( - ctx: click.Context, - _: click.Parameter, # param is not used - value: Any, -) -> Any: - """Validate provided system name depending on the the tool name.""" - tool_name = ctx.params["tool_name"] - tools = get_tool(tool_name, value) - if not tools: - supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)] - raise click.BadParameter( - message="'{}' is not one of {}.".format( - value, - ", ".join("'{}'".format(system) for system in supported_systems), - ), - ctx=ctx, - ) - return value - - -@tool_cmd.command(name="details") -@click.option( - "-n", - "--name", - "tool_name", - type=click.Choice(get_unique_tool_names()), - required=True, 
-) -@click.option( - "-s", - "--system", - "system_name", - callback=validate_system, - required=False, -) -@click.pass_context -@middleware_signal_handler -@middleware_exception_handler -def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None: - """Details of a specific tool.""" - tools = get_tool(tool_name, system_name) - if get_format(ctx) == "json": - tools_details = [s.get_details() for s in tools] - print(json.dumps(tools_details)) - else: - for tool in tools: - tool_details = tool.get_details() - tool_details_template = 'Tool "{name}" details\nDescription: {description}' - - print( - tool_details_template.format( - name=tool_details["name"], - description=tool_details["description"], - ) - ) - - print( - "\nSupported systems: {}".format( - ", ".join(tool_details["supported_systems"]) - ) - ) - - command_details = tool_details["commands"] - - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -# pylint: disable=too-many-arguments -@tool_cmd.command(name="execute") -@click.option( - "-n", - "--name", - "tool_name", - type=click.Choice(get_unique_tool_names()), - required=True, -) -@click.option("-p", "--param", "tool_params", multiple=True) -@click.option( - "-s", - "--system", - "system_name", - callback=validate_system, - required=False, -) -@middleware_signal_handler -@middleware_exception_handler -def execute_cmd( - tool_name: str, tool_params: List[str], system_name: Optional[str] -) -> None: - """Execute tool commands.""" - execute_tool_command(tool_name, tool_params, system_name) diff --git a/src/aiet/main.py b/src/aiet/main.py deleted file mode 100644 index 6898ad9..0000000 --- a/src/aiet/main.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Entry point module of AIET.""" -from aiet.cli import cli - - -def main() -> None: - """Entry point of aiet application.""" - cli() # pylint: disable=no-value-for-parameter - - -if __name__ == "__main__": - main() diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore deleted file mode 100644 index 0226166..0000000 --- a/src/aiet/resources/applications/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore deleted file mode 100644 index 0226166..0000000 --- a/src/aiet/resources/systems/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json deleted file mode 100644 index c12f291..0000000 --- a/src/aiet/resources/tools/vela/aiet-config.json +++ /dev/null @@ -1,73 +0,0 @@ -[ - { - "name": "vela", - "description": "Neural network model compiler for Arm Ethos-U NPUs", - "supported_systems": [ - { - "name": "Corstone-300: Cortex-M55+Ethos-U55" - }, - { - "name": "Corstone-310: Cortex-M85+Ethos-U55" - }, - { - "name": "Corstone-300: Cortex-M55+Ethos-U65", - "variables": { - "accelerator_config_prefix": "ethos-u65", - "system_config": "Ethos_U65_High_End", - "shared_sram": "U65_Shared_Sram" - }, - "user_params": { - "run": [ - { - "description": "MACs per cycle", - "values": [ - "256", - "512" - ], - "default_value": "512", - "alias": "mac" - } - ] - } - } - ], - "variables": { - "accelerator_config_prefix": "ethos-u55", - "system_config": 
"Ethos_U55_High_End_Embedded", - "shared_sram": "U55_Shared_Sram" - }, - "commands": { - "run": [ - "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance" - ] - }, - "user_params": { - "run": [ - { - "description": "MACs per cycle", - "values": [ - "32", - "64", - "128", - "256" - ], - "default_value": "128", - "alias": "mac" - }, - { - "name": "--input-model", - "description": "Path to the TFLite model", - "values": [], - "alias": "input" - }, - { - "name": "--output-model", - "description": "Path to the output model file of the vela-optimisation step. The vela output is saved in the parent directory.", - "values": [], - "default_value": "output_model.tflite", - "alias": "output" - } - ] - } - } -] diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/src/aiet/resources/tools/vela/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py deleted file mode 100644 index 7c700b1..0000000 --- a/src/aiet/resources/tools/vela/check_model.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Check if a TFLite model file is Vela-optimised.""" -import struct -from pathlib import Path - -from ethosu.vela.tflite.Model import Model - -from aiet.cli.common import InvalidTFLiteFileError -from aiet.cli.common import ModelOptimisedException -from aiet.utils.fs import read_file_as_bytearray - - -def get_model_from_file(input_model_file: Path) -> Model: - """Generate Model instance from TFLite file using flatc generated code.""" - buffer = read_file_as_bytearray(input_model_file) - try: - model = Model.GetRootAsModel(buffer, 0) - except (TypeError, RuntimeError, struct.error) as tflite_error: - raise InvalidTFLiteFileError( - f"Error reading in model from {input_model_file}." - ) from tflite_error - return model - - -def is_vela_optimised(tflite_model: Model) -> bool: - """Return True if 'ethos-u' custom operator found in the Model.""" - operators = get_operators_from_model(tflite_model) - - custom_codes = get_custom_codes_from_operators(operators) - - return check_custom_codes_for_ethosu(custom_codes) - - -def get_operators_from_model(tflite_model: Model) -> list: - """Return list of the unique operator codes used in the Model.""" - return [ - tflite_model.OperatorCodes(index) - for index in range(tflite_model.OperatorCodesLength()) - ] - - -def get_custom_codes_from_operators(operators: list) -> list: - """Return list of each operator's CustomCode() strings, if they exist.""" - return [ - operator.CustomCode() - for operator in operators - if operator.CustomCode() is not None - ] - - -def check_custom_codes_for_ethosu(custom_codes: list) -> bool: - """Check for existence of ethos-u string in the custom codes.""" - return any( - custom_code_name.decode("utf-8") == "ethos-u" - for custom_code_name in custom_codes - ) - - -def check_model(tflite_file_name: str) -> None: - """Raise an exception if model in given file is Vela optimised.""" - tflite_path = Path(tflite_file_name) - - tflite_model = 
get_model_from_file(tflite_path) - - if is_vela_optimised(tflite_model): - raise ModelOptimisedException( - f"TFLite model in {tflite_file_name} is already " - f"vela optimised ('ethos-u' custom op detected)." - ) - - print( - f"TFLite model in {tflite_file_name} is not vela optimised " - f"('ethos-u' custom op not detected)." - ) diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py deleted file mode 100644 index 2c1b0be..0000000 --- a/src/aiet/resources/tools/vela/run_vela.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Wrapper to only run Vela when the input is not already optimised.""" -import shutil -import subprocess -from pathlib import Path -from typing import Tuple - -import click - -from aiet.cli.common import ModelOptimisedException -from aiet.resources.tools.vela.check_model import check_model - - -def vela_output_model_path(input_model: str, output_dir: str) -> Path: - """Construct the path to the Vela output file.""" - in_path = Path(input_model) - tflite_vela = Path(output_dir) / f"{in_path.stem}_vela{in_path.suffix}" - return tflite_vela - - -def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None: - """Execute vela as external call.""" - cmd = ["vela"] + list(vela_args) - cmd += ["--output-dir", str(output_dir)] # Re-add parsed out_dir to arguments - cmd += [input_model] - subprocess.run(cmd, check=True) - - -@click.command(context_settings=dict(ignore_unknown_options=True)) -@click.option( - "--input-model", - "-i", - type=click.Path(exists=True, file_okay=True, readable=True), - required=True, -) -@click.option("--output-model", "-o", type=click.Path(), required=True) -# Collect the remaining arguments to be directly forwarded to Vela -@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED) -def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None: 
- """Check input, run Vela (if needed) and copy optimised file to destination.""" - output_dir = Path(output_model).parent - try: - check_model(input_model) # raises an exception if already Vela-optimised - execute_vela(vela_args, output_dir, input_model) - print("Vela optimisation complete.") - src_model = vela_output_model_path(input_model, str(output_dir)) - except ModelOptimisedException as ex: - # Input already optimized: copy input file to destination path and return - print(f"Input already vela-optimised.\n{ex}") - src_model = Path(input_model) - except subprocess.CalledProcessError as ex: - print(ex) - raise SystemExit(ex.returncode) from ex - - try: - shutil.copyfile(src_model, output_model) - except (shutil.SameFileError, OSError) as ex: - print(ex) - raise SystemExit(ex.errno) from ex - - -def main() -> None: - """Entry point of check_model application.""" - run_vela() # pylint: disable=no-value-for-parameter diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini deleted file mode 100644 index 5996553..0000000 --- a/src/aiet/resources/tools/vela/vela.ini +++ /dev/null @@ -1,53 +0,0 @@ -; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates. 
-; SPDX-License-Identifier: Apache-2.0 - -; ----------------------------------------------------------------------------- -; Vela configuration file - -; ----------------------------------------------------------------------------- -; System Configuration - -; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s) -[System_Config.Ethos_U55_High_End_Embedded] -core_clock=500e6 -axi0_port=Sram -axi1_port=OffChipFlash -Sram_clock_scale=1.0 -Sram_burst_length=32 -Sram_read_latency=32 -Sram_write_latency=32 -OffChipFlash_clock_scale=0.125 -OffChipFlash_burst_length=128 -OffChipFlash_read_latency=64 -OffChipFlash_write_latency=64 - -; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s) -[System_Config.Ethos_U65_High_End] -core_clock=1e9 -axi0_port=Sram -axi1_port=Dram -Sram_clock_scale=1.0 -Sram_burst_length=32 -Sram_read_latency=32 -Sram_write_latency=32 -Dram_clock_scale=0.234375 -Dram_burst_length=128 -Dram_read_latency=500 -Dram_write_latency=250 - -; ----------------------------------------------------------------------------- -; Memory Mode - -; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software -; The non-SRAM memory is assumed to be read-only -[Memory_Mode.U55_Shared_Sram] -const_mem_area=Axi1 -arena_mem_area=Axi0 -cache_mem_area=Axi0 -arena_cache_size=4194304 - -[Memory_Mode.U65_Shared_Sram] -const_mem_area=Axi1 -arena_mem_area=Axi0 -cache_mem_area=Axi0 -arena_cache_size=2097152 diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py deleted file mode 100644 index fc7ef7c..0000000 --- a/src/aiet/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""This module contains all utils shared across aiet project.""" diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py deleted file mode 100644 index ea99a69..0000000 --- a/src/aiet/utils/fs.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module to host all file system related functions.""" -import importlib.resources as pkg_resources -import re -import shutil -from pathlib import Path -from typing import Any -from typing import Literal -from typing import Optional - -ResourceType = Literal["applications", "systems", "tools"] - - -def get_aiet_resources() -> Path: - """Get resources folder path.""" - with pkg_resources.path("aiet", "__init__.py") as init_path: - project_root = init_path.parent - return project_root / "resources" - - -def get_resources(name: ResourceType) -> Path: - """Return the absolute path of the specified resource. - - It uses importlib to return resources packaged with MANIFEST.in. 
- """ - if not name: - raise ResourceWarning("Resource name is not provided") - - resource_path = get_aiet_resources() / name - if resource_path.is_dir(): - return resource_path - - raise ResourceWarning("Resource '{}' not found.".format(name)) - - -def copy_directory_content(source: Path, destination: Path) -> None: - """Copy content of the source directory into destination directory.""" - for item in source.iterdir(): - src = source / item.name - dest = destination / item.name - - if src.is_dir(): - shutil.copytree(src, dest) - else: - shutil.copy2(src, dest) - - -def remove_resource(resource_directory: str, resource_type: ResourceType) -> None: - """Remove resource data.""" - resources = get_resources(resource_type) - - resource_location = resources / resource_directory - if not resource_location.exists(): - raise Exception("Resource {} does not exist".format(resource_directory)) - - if not resource_location.is_dir(): - raise Exception("Wrong resource {}".format(resource_directory)) - - shutil.rmtree(resource_location) - - -def remove_directory(directory_path: Optional[Path]) -> None: - """Remove directory.""" - if not directory_path or not directory_path.is_dir(): - raise Exception("No directory path provided") - - shutil.rmtree(directory_path) - - -def recreate_directory(directory_path: Optional[Path]) -> None: - """Recreate directory.""" - if not directory_path: - raise Exception("No directory path provided") - - if directory_path.exists() and not directory_path.is_dir(): - raise Exception( - "Path {} does exist and it is not a directory".format(str(directory_path)) - ) - - if directory_path.is_dir(): - remove_directory(directory_path) - - directory_path.mkdir() - - -def read_file(file_path: Path, mode: Optional[str] = None) -> Any: - """Read file as string or bytearray.""" - if file_path.is_file(): - if mode is not None: - # Ignore pylint warning because mode can be 'binary' as well which - # is not compatible with specifying encodings. 
- with open(file_path, mode) as file: # pylint: disable=unspecified-encoding - return file.read() - else: - with open(file_path, encoding="utf-8") as file: - return file.read() - - if mode == "rb": - return b"" - return "" - - -def read_file_as_string(file_path: Path) -> str: - """Read file as string.""" - return str(read_file(file_path)) - - -def read_file_as_bytearray(file_path: Path) -> bytearray: - """Read a file as bytearray.""" - return bytearray(read_file(file_path, mode="rb")) - - -def valid_for_filename(value: str, replacement: str = "") -> str: - """Replace non alpha numeric characters.""" - return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII) diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py deleted file mode 100644 index 6d3cd22..0000000 --- a/src/aiet/utils/helpers.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Helpers functions.""" -import logging -from typing import Any - - -def set_verbosity( - ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument -) -> None: - """Set the logging level according to the verbosity.""" - # Unused arguments must be present here in definition as these are required in - # function definition when set as a callback - if verbosity == 1: - logging.getLogger().setLevel(logging.INFO) - elif verbosity > 1: - logging.getLogger().setLevel(logging.DEBUG) diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py deleted file mode 100644 index b6f4357..0000000 --- a/src/aiet/utils/proc.py +++ /dev/null @@ -1,283 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Processes module. - -This module contains all classes and functions for dealing with Linux -processes. 
-""" -import csv -import datetime -import logging -import shlex -import signal -import time -from pathlib import Path -from typing import Any -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Tuple - -import psutil -from sh import Command -from sh import CommandNotFound -from sh import ErrorReturnCode -from sh import RunningCommand - -from aiet.utils.fs import valid_for_filename - - -class CommandFailedException(Exception): - """Exception for failed command execution.""" - - -class ShellCommand: - """Wrapper class for shell commands.""" - - def __init__(self, base_log_path: str = "/tmp") -> None: - """Initialise the class. - - base_log_path: it is the base directory where logs will be stored - """ - self.base_log_path = base_log_path - - def run( - self, - cmd: str, - *args: str, - _cwd: Optional[Path] = None, - _tee: bool = True, - _bg: bool = True, - _out: Any = None, - _err: Any = None, - _search_paths: Optional[List[Path]] = None - ) -> RunningCommand: - """Run the shell command with the given arguments. - - There are special arguments that modify the behaviour of the process. - _cwd: current working directory - _tee: it redirects the stdout both to console and file - _bg: if True, it runs the process in background and the command is not - blocking. 
- _out: use this object for stdout redirect, - _err: use this object for stderr redirect, - _search_paths: If presented used for searching executable - """ - try: - kwargs = {} - if _cwd: - kwargs["_cwd"] = str(_cwd) - command = Command(cmd, _search_paths).bake(args, **kwargs) - except CommandNotFound as error: - logging.error("Command '%s' not found", error.args[0]) - raise error - - out, err = _out, _err - if not _out and not _err: - out, err = [ - str(item) - for item in self.get_stdout_stderr_paths(self.base_log_path, cmd) - ] - - return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) - - @classmethod - def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]: - """Construct and returns the paths of stdout/stderr files.""" - timestamp = datetime.datetime.now().timestamp() - base_path = Path(base_log_path) - base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) - stdout = base.with_suffix(".out") - stderr = base.with_suffix(".err") - try: - stdout.touch() - stderr.touch() - except FileNotFoundError as error: - logging.error("File not found: %s", error.filename) - raise error - return stdout, stderr - - -def parse_command(command: str, shell: str = "bash") -> List[str]: - """Parse command.""" - cmd, *args = shlex.split(command, posix=True) - - if is_shell_script(cmd): - args = [cmd] + args - cmd = shell - - return [cmd] + args - - -def get_stdout_stderr_paths( - command: str, base_log_path: str = "/tmp" -) -> Tuple[Path, Path]: - """Construct and returns the paths of stdout/stderr files.""" - cmd, *_ = parse_command(command) - - return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd) - - -def execute_command( # pylint: disable=invalid-name - command: str, - cwd: Path, - bg: bool = False, - shell: str = "bash", - out: Any = None, - err: Any = None, -) -> RunningCommand: - """Execute shell command.""" - cmd, *args = parse_command(command, shell) - - search_paths = None - if cmd != shell and (cwd 
/ cmd).is_file(): - search_paths = [cwd] - - return ShellCommand().run( - cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err - ) - - -def is_shell_script(cmd: str) -> bool: - """Check if command is shell script.""" - return cmd.endswith(".sh") - - -def run_and_wait( - command: str, - cwd: Path, - terminate_on_error: bool = False, - out: Any = None, - err: Any = None, -) -> Tuple[int, bytearray, bytearray]: - """ - Run command and wait while it is executing. - - Returns a tuple: (exit_code, stdout, stderr) - """ - running_cmd: Optional[RunningCommand] = None - try: - running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) - return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr - except ErrorReturnCode as cmd_failed: - raise CommandFailedException() from cmd_failed - except Exception as error: - is_running = running_cmd is not None and running_cmd.is_alive() - if terminate_on_error and is_running: - print("Terminating ...") - terminate_command(running_cmd) - - raise error - - -def terminate_command( - running_cmd: RunningCommand, - wait: bool = True, - wait_period: float = 0.5, - number_of_attempts: int = 20, -) -> None: - """Terminate running command.""" - try: - running_cmd.process.signal_group(signal.SIGINT) - if wait: - for _ in range(number_of_attempts): - time.sleep(wait_period) - if not running_cmd.is_alive(): - return - print( - "Unable to terminate process {}. 
Sending SIGTERM...".format( - running_cmd.process.pid - ) - ) - running_cmd.process.signal_group(signal.SIGTERM) - except ProcessLookupError: - pass - - -def terminate_external_process( - process: psutil.Process, - wait_period: float = 0.5, - number_of_attempts: int = 20, - wait_for_termination: float = 5.0, -) -> None: - """Terminate external process.""" - try: - process.terminate() - for _ in range(number_of_attempts): - if not process.is_running(): - return - time.sleep(wait_period) - - if process.is_running(): - process.terminate() - time.sleep(wait_for_termination) - except psutil.Error: - print("Unable to terminate process") - - -class ProcessInfo(NamedTuple): - """Process information.""" - - name: str - executable: str - cwd: str - pid: int - - -def save_process_info(pid: int, pid_file: Path) -> None: - """Save process information to file.""" - try: - parent = psutil.Process(pid) - children = parent.children(recursive=True) - family = [parent] + children - - with open(pid_file, "w", encoding="utf-8") as file: - csv_writer = csv.writer(file) - for member in family: - process_info = ProcessInfo( - name=member.name(), - executable=member.exe(), - cwd=member.cwd(), - pid=member.pid, - ) - csv_writer.writerow(process_info) - except psutil.NoSuchProcess: - # if process does not exist or finishes before - # function call then nothing could be saved - # just ignore this exception and exit - pass - - -def read_process_info(pid_file: Path) -> List[ProcessInfo]: - """Read information about previous system processes.""" - if not pid_file.is_file(): - return [] - - result = [] - with open(pid_file, encoding="utf-8") as file: - csv_reader = csv.reader(file) - for row in csv_reader: - name, executable, cwd, pid = row - result.append( - ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) - ) - - return result - - -def print_command_stdout(command: RunningCommand) -> None: - """Print the stdout of a command. - - The command has 2 states: running and done. 
- If the command is running, the output is taken by the running process. - If the command has ended its execution, the stdout is taken from stdout - property - """ - if command.is_alive(): - while True: - try: - print(command.next(), end="") - except StopIteration: - break - else: - print(command.stdout) |