diff options
Diffstat (limited to 'src/aiet/backend')
-rw-r--r-- | src/aiet/backend/__init__.py | 3 | ||||
-rw-r--r-- | src/aiet/backend/application.py | 187 | ||||
-rw-r--r-- | src/aiet/backend/common.py | 532 | ||||
-rw-r--r-- | src/aiet/backend/config.py | 107 | ||||
-rw-r--r-- | src/aiet/backend/controller.py | 134 | ||||
-rw-r--r-- | src/aiet/backend/execution.py | 859 | ||||
-rw-r--r-- | src/aiet/backend/output_parser.py | 176 | ||||
-rw-r--r-- | src/aiet/backend/protocol.py | 325 | ||||
-rw-r--r-- | src/aiet/backend/source.py | 209 | ||||
-rw-r--r-- | src/aiet/backend/system.py | 289 | ||||
-rw-r--r-- | src/aiet/backend/tool.py | 109 |
11 files changed, 0 insertions, 2930 deletions
diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py deleted file mode 100644 index 3d60372..0000000 --- a/src/aiet/backend/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Backend module.""" diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py deleted file mode 100644 index f6ef815..0000000 --- a/src/aiet/backend/application.py +++ /dev/null @@ -1,187 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application backend module.""" -import re -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_application_or_tool_configs -from aiet.backend.common import load_config -from aiet.backend.common import remove_backend -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import ExtendedApplicationConfig -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import get_source -from aiet.utils.fs import get_resources - - -def get_available_application_directory_names() -> List[str]: - """Return a list of directory names for all available applications.""" - return [entry.name for entry in get_backend_directories("applications")] - - -def get_available_applications() -> List["Application"]: - """Return a list with all available applications.""" - available_applications = [] - for config_json in get_backend_configs("applications"): - config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json)) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - applications = load_applications(config_entry) - available_applications += applications - - return sorted(available_applications, key=lambda application: application.name) - - -def get_application( - application_name: str, system_name: Optional[str] = None -) -> List["Application"]: - """Return a list of application instances with provided name.""" - return [ - application - for application in get_available_applications() - if application.name == application_name - and (not system_name or application.can_run_on(system_name)) - ] - - -def install_application(source_path: Path) -> None: - """Install application.""" - try: - source = get_source(source_path) - config = cast(List[ExtendedApplicationConfig], source.config()) - applications_to_install = [ - s for entry in config for s in load_applications(entry) - ] - except Exception as error: - raise ConfigurationException("Unable to read application definition") from error - - if not applications_to_install: - raise ConfigurationException("No application definition found") - - available_applications = get_available_applications() - already_installed = [ - s for s in applications_to_install if s in available_applications - ] - if already_installed: - names = {application.name for application in already_installed} - raise ConfigurationException( - "Applications [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_resources("applications")) - - -def remove_application(directory_name: str) -> None: - """Remove application directory.""" - remove_backend(directory_name, "applications") - - -def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: - """Extract a list of unique application names of all application available.""" - return list( - set( - application.name - for application in get_available_applications() - if not system_name or application.can_run_on(system_name) - ) - ) - - -class Application(Backend): - """Class for representing a single application component.""" - - def __init__(self, config: ApplicationConfig) -> None: - """Construct a Application instance from a dict.""" - super().__init__(config) - - self.supported_systems = config.get("supported_systems", []) - self.deploy_data = config.get("deploy_data", []) - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Application): - return False - - return ( - super().__eq__(other) - and self.name == other.name - and set(self.supported_systems) == set(other.supported_systems) - ) - - def can_run_on(self, system_name: str) -> bool: - """Check if the application can run on the system passed as argument.""" - return system_name in self.supported_systems - - def get_deploy_data(self) -> List[DataPaths]: - """Validate and return data specified in the config file.""" - if self.config_location is None: - raise ConfigurationException( - "Unable to get application {} config location".format(self.name) - ) - - deploy_data = [] - for item in self.deploy_data: - src, dst = item - src_full_path = self.config_location / src - assert src_full_path.exists(), "{} does not exists".format(src_full_path) - deploy_data.append(DataPaths(src_full_path, dst)) - return deploy_data - - def get_details(self) -> Dict[str, Any]: - """Return dictionary with information about the Application instance.""" - output = { - "type": "application", - "name": self.name, - "description": self.description, - "supported_systems": self.supported_systems, - "commands": self._get_command_details(), - } - - return output - - def remove_unused_params(self) -> None: - """Remove unused params in commands. - - After merging default and system related configuration application - could have parameters that are not being used in commands. They - should be removed. - """ - for command in self.commands.values(): - indexes_or_aliases = [ - m - for cmd_str in command.command_strings - for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str) - ] - - only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) - if only_aliases: - used_params = [ - param - for param in command.params - if param.alias in indexes_or_aliases - ] - command.params = used_params - - -def load_applications(config: ExtendedApplicationConfig) -> List[Application]: - """Load application. - - Application configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - Application instance with appropriate configuration. - """ - configs = load_application_or_tool_configs(config, ApplicationConfig) - applications = [Application(cfg) for cfg in configs] - for application in applications: - application.remove_unused_params() - return applications diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py deleted file mode 100644 index b887ee7..0000000 --- a/src/aiet/backend/common.py +++ /dev/null @@ -1,532 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain all common functions for the backends.""" -import json -import logging -import re -from abc import ABC -from collections import Counter -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict -from typing import Final -from typing import IO -from typing import Iterable -from typing import List -from typing import Match -from typing import NamedTuple -from typing import Optional -from typing import Pattern -from typing import Tuple -from typing import Type -from typing import Union - -from aiet.backend.config import BackendConfig -from aiet.backend.config import BaseBackendConfig -from aiet.backend.config import NamedExecutionConfig -from aiet.backend.config import UserParamConfig -from aiet.backend.config import UserParamsConfig -from aiet.utils.fs import get_resources -from aiet.utils.fs import remove_resource -from aiet.utils.fs import ResourceType - - -AIET_CONFIG_FILE: Final[str] = "aiet-config.json" - - -class ConfigurationException(Exception): - """Configuration exception.""" - - -def get_backend_config(dir_path: Path) -> Path: - """Get path to backendir configuration file.""" - return dir_path / AIET_CONFIG_FILE - - -def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend configs for provided resource_type.""" - return ( - get_backend_config(entry) for entry in get_backend_directories(resource_type) - ) - - -def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend directories for provided resource_type.""" - return ( - entry - for entry in get_resources(resource_type).iterdir() - if is_backend_directory(entry) - ) - - -def is_backend_directory(dir_path: Path) -> bool: - """Check if path is backend's configuration directory.""" - return dir_path.is_dir() and get_backend_config(dir_path).is_file() - - -def remove_backend(directory_name: str, resource_type: ResourceType) -> None: - """Remove backend with provided type and directory_name.""" - if not directory_name: - raise Exception("No directory name provided") - - remove_resource(directory_name, resource_type) - - -def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: - """Return a loaded json file.""" - if config is None: - raise Exception("Unable to read config") - - if isinstance(config, Path): - with config.open() as json_file: - return cast(BackendConfig, json.load(json_file)) - - return cast(BackendConfig, json.load(config)) - - -def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: - """Split the parameter string in name and optional value. - - It manages the following cases: - --param=1 -> --param, 1 - --param 1 -> --param, 1 - --flag -> --flag, None - """ - data = re.split(" |=", parameter) - if len(data) == 1: - param_name = data[0] - param_value = None - else: - param_name = " ".join(data[0:-1]) - param_value = data[-1] - return param_name, param_value - - -class DataPaths(NamedTuple): - """DataPaths class.""" - - src: Path - dst: str - - -class Backend(ABC): - """Backend class.""" - - # pylint: disable=too-many-instance-attributes - - def __init__(self, config: BaseBackendConfig): - """Initialize backend.""" - name = config.get("name") - if not name: - raise ConfigurationException("Name is empty") - - self.name = name - self.description = config.get("description", "") - self.config_location = config.get("config_location") - self.variables = config.get("variables", {}) - self.build_dir = config.get("build_dir") - self.lock = config.get("lock", False) - if self.build_dir: - self.build_dir = self._substitute_variables(self.build_dir) - self.annotations = config.get("annotations", {}) - - self._parse_commands_and_params(config) - - def validate_parameter(self, command_name: str, parameter: str) -> bool: - """Validate the parameter string against the application configuration. - - We take the parameter string, extract the parameter name/value and - check them against the current configuration. - """ - param_name, param_value = parse_raw_parameter(parameter) - valid_param_name = valid_param_value = False - - command = self.commands.get(command_name) - if not command: - raise AttributeError("Unknown command: '{}'".format(command_name)) - - # Iterate over all available parameters until we have a match. - for param in command.params: - if self._same_parameter(param_name, param): - valid_param_name = True - # This is a non-empty list - if param.values: - # We check if the value is allowed in the configuration - valid_param_value = param_value in param.values - else: - # In this case we don't validate the value and accept - # whatever we have set. - valid_param_value = True - break - - return valid_param_name and valid_param_value - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Backend): - return False - - return ( - self.name == other.name - and self.description == other.description - and self.commands == other.commands - ) - - def __repr__(self) -> str: - """Represent the Backend instance by its name.""" - return self.name - - def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: - """Parse commands and user parameters.""" - self.commands: Dict[str, Command] = {} - - commands = config.get("commands") - if commands: - params = config.get("user_params") - - for command_name in commands.keys(): - command_params = self._parse_params(params, command_name) - command_strings = [ - self._substitute_variables(cmd) - for cmd in commands.get(command_name, []) - ] - self.commands[command_name] = Command(command_strings, command_params) - - def _substitute_variables(self, str_val: str) -> str: - """Substitute variables in string. - - Variables is being substituted at backend's creation stage because - they could contain references to other params which will be - resolved later. - """ - if not str_val: - return str_val - - var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}") - - def var_value(match: Match) -> str: - var_name = match["var_name"] - if var_name not in self.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) - - return self.variables[var_name] - - return var_pattern.sub(var_value, str_val) # type: ignore - - @classmethod - def _parse_params( - cls, params: Optional[UserParamsConfig], command: str - ) -> List["Param"]: - if not params: - return [] - - return [cls._parse_param(p) for p in params.get(command, [])] - - @classmethod - def _parse_param(cls, param: UserParamConfig) -> "Param": - """Parse a single parameter.""" - name = param.get("name") - if name is not None and not name: - raise ConfigurationException("Parameter has an empty 'name' attribute.") - values = param.get("values", None) - default_value = param.get("default_value", None) - description = param.get("description", "") - alias = param.get("alias") - - return Param( - name=name, - description=description, - values=values, - default_value=default_value, - alias=alias, - ) - - def _get_command_details(self) -> Dict: - command_details = { - command_name: command.get_details() - for command_name, command in self.commands.items() - } - return command_details - - def _get_user_param_value( - self, user_params: List[str], param: "Param" - ) -> Optional[str]: - """Get the user-specified value of a parameter.""" - for user_param in user_params: - user_param_name, user_param_value = parse_raw_parameter(user_param) - if user_param_name == param.name: - warn_message = ( - "The direct use of parameter name is deprecated" - " and might be removed in the future.\n" - f"Please use alias '{param.alias}' instead of " - "'{user_param_name}' to provide the parameter." - ) - logging.warning(warn_message) - - if self._same_parameter(user_param_name, param): - return user_param_value - - return None - - @staticmethod - def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: - """Compare user parameter name with param name or alias.""" - # Strip the "=" sign in the param_name. This is needed just for - # comparison with the parameters passed by the user. - # The equal sign needs to be honoured when re-building the - # parameter back. - param_name = None if not param.name else param.name.rstrip("=") - return user_param_name_or_alias in [param_name, param.alias] - - def resolved_parameters( - self, command_name: str, user_params: List[str] - ) -> List[Tuple[Optional[str], "Param"]]: - """Return list of parameters with values.""" - result: List[Tuple[Optional[str], "Param"]] = [] - command = self.commands.get(command_name) - if not command: - return result - - for param in command.params: - value = self._get_user_param_value(user_params, param) - if not value: - value = param.default_value - result.append((value, param)) - - return result - - def build_command( - self, - command_name: str, - user_params: List[str], - param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], - ) -> List[str]: - """ - Return a list of executable command strings. - - Given a command and associated parameters, returns a list of executable command - strings. - """ - command = self.commands.get(command_name) - if not command: - raise ConfigurationException( - "Command '{}' could not be found.".format(command_name) - ) - - commands_to_run = [] - - params_values = self.resolved_parameters(command_name, user_params) - for cmd_str in command.command_strings: - cmd_str = resolve_all_parameters( - cmd_str, param_resolver, command_name, params_values - ) - commands_to_run.append(cmd_str) - - return commands_to_run - - -class Param: - """Class for representing a generic application parameter.""" - - def __init__( # pylint: disable=too-many-arguments - self, - name: Optional[str], - description: str, - values: Optional[List[str]] = None, - default_value: Optional[str] = None, - alias: Optional[str] = None, - ) -> None: - """Construct a Param instance.""" - if not name and not alias: - raise ConfigurationException( - "Either name, alias or both must be set to identify a parameter." - ) - self.name = name - self.values = values - self.description = description - self.default_value = default_value - self.alias = alias - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Param.""" - return {key: value for key, value in self.__dict__.items() if value} - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Param): - return False - - return ( - self.name == other.name - and self.values == other.values - and self.default_value == other.default_value - and self.description == other.description - ) - - -class Command: - """Class for representing a command.""" - - def __init__( - self, command_strings: List[str], params: Optional[List[Param]] = None - ) -> None: - """Construct a Command instance.""" - self.command_strings = command_strings - - if params: - self.params = params - else: - self.params = [] - - self._validate() - - def _validate(self) -> None: - """Validate command.""" - if not self.params: - return - - aliases = [param.alias for param in self.params if param.alias is not None] - repeated_aliases = [ - alias for alias, count in Counter(aliases).items() if count > 1 - ] - - if repeated_aliases: - raise ConfigurationException( - "Non unique aliases {}".format(", ".join(repeated_aliases)) - ) - - both_name_and_alias = [ - param.name - for param in self.params - if param.name in aliases and param.name != param.alias - ] - if both_name_and_alias: - raise ConfigurationException( - "Aliases {} could not be used as parameter name".format( - ", ".join(both_name_and_alias) - ) - ) - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Command.""" - output = { - "command_strings": self.command_strings, - "user_params": [param.get_details() for param in self.params], - } - return output - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Command): - return False - - return ( - self.command_strings == other.command_strings - and self.params == other.params - ) - - -def resolve_all_parameters( - str_val: str, - param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], - command_name: Optional[str] = None, - params_values: Optional[List[Tuple[Optional[str], Param]]] = None, -) -> str: - """Resolve all parameters in the string.""" - if not str_val: - return str_val - - param_pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}") - while param_pattern.findall(str_val): - str_val = param_pattern.sub( - lambda m: param_resolver( - m["param_name"], command_name or "", params_values or [] - ), - str_val, - ) - return str_val - - -def load_application_or_tool_configs( - config: Any, - config_type: Type[Any], - is_system_required: bool = True, -) -> Any: - """Get one config for each system supported by the application/tool. - - The configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - config with appropriate configuration. - """ - merged_configs = [] - supported_systems: Optional[List[NamedExecutionConfig]] = config.get( - "supported_systems" - ) - if not supported_systems: - if is_system_required: - raise ConfigurationException("No supported systems definition provided") - # Create an empty system to be used in the parsing below - supported_systems = [cast(NamedExecutionConfig, {})] - - default_user_params = config.get("user_params", {}) - - def merge_config(system: NamedExecutionConfig) -> Any: - system_name = system.get("name") - if not system_name and is_system_required: - raise ConfigurationException( - "Unable to read supported system definition, name is missed" - ) - - merged_config = config_type(**config) - merged_config["supported_systems"] = [system_name] if system_name else [] - # merge default configuration and specific to the system - merged_config["commands"] = { - **config.get("commands", {}), - **system.get("commands", {}), - } - - params = {} - tool_user_params = system.get("user_params", {}) - command_names = tool_user_params.keys() | default_user_params.keys() - for command_name in command_names: - if command_name not in merged_config["commands"]: - continue - - params_default = default_user_params.get(command_name, []) - params_tool = tool_user_params.get(command_name, []) - if not params_default or not params_tool: - params[command_name] = params_tool or params_default - if params_default and params_tool: - if any(not p.get("alias") for p in params_default): - raise ConfigurationException( - "Default parameters for command {} should have aliases".format( - command_name - ) - ) - if any(not p.get("alias") for p in params_tool): - raise ConfigurationException( - "{} parameters for command {} should have aliases".format( - system_name, command_name - ) - ) - - merged_by_alias = { - **{p.get("alias"): p for p in params_default}, - **{p.get("alias"): p for p in params_tool}, - } - params[command_name] = list(merged_by_alias.values()) - - merged_config["user_params"] = params - merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) - merged_config["lock"] = system.get("lock", config.get("lock", False)) - merged_config["variables"] = { - **config.get("variables", {}), - **system.get("variables", {}), - } - return merged_config - - merged_configs = [merge_config(system) for system in supported_systems] - - return merged_configs diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py deleted file mode 100644 index dd42012..0000000 --- a/src/aiet/backend/config.py +++ /dev/null @@ -1,107 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain definition of backend configuration.""" -from pathlib import Path -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import TypedDict -from typing import Union - - -class UserParamConfig(TypedDict, total=False): - """User parameter configuration.""" - - name: Optional[str] - default_value: str - values: List[str] - description: str - alias: str - - -UserParamsConfig = Dict[str, List[UserParamConfig]] - - -class ExecutionConfig(TypedDict, total=False): - """Execution configuration.""" - - commands: Dict[str, List[str]] - user_params: UserParamsConfig - build_dir: str - variables: Dict[str, str] - lock: bool - - -class NamedExecutionConfig(ExecutionConfig): - """Execution configuration with name.""" - - name: str - - -class BaseBackendConfig(ExecutionConfig, total=False): - """Base backend configuration.""" - - name: str - description: str - config_location: Path - annotations: Dict[str, Union[str, List[str]]] - - -class ApplicationConfig(BaseBackendConfig, total=False): - """Application configuration.""" - - supported_systems: List[str] - deploy_data: List[Tuple[str, str]] - - -class ExtendedApplicationConfig(BaseBackendConfig, total=False): - """Extended application configuration.""" - - supported_systems: List[NamedExecutionConfig] - deploy_data: List[Tuple[str, str]] - - -class ProtocolConfig(TypedDict, total=False): - """Protocol config.""" - - protocol: Literal["local", "ssh"] - - -class SSHConfig(ProtocolConfig, total=False): - """SSH configuration.""" - - username: str - password: str - hostname: str - port: str - - -class LocalProtocolConfig(ProtocolConfig, total=False): - """Local protocol config.""" - - -class SystemConfig(BaseBackendConfig, total=False): - """System configuration.""" - - data_transfer: Union[SSHConfig, LocalProtocolConfig] - reporting: Dict[str, Dict] - - -class ToolConfig(BaseBackendConfig, total=False): - """Tool configuration.""" - - supported_systems: List[str] - - -class ExtendedToolConfig(BaseBackendConfig, total=False): - """Extended tool configuration.""" - - supported_systems: List[NamedExecutionConfig] - - -BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig] -BackendConfig = Union[ - List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig] -] diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py deleted file mode 100644 index 2650902..0000000 --- a/src/aiet/backend/controller.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Controller backend module.""" -import time -from pathlib import Path -from typing import List -from typing import Optional -from typing import Tuple - -import psutil -import sh - -from aiet.backend.common import ConfigurationException -from aiet.utils.fs import read_file_as_string -from aiet.utils.proc import execute_command -from aiet.utils.proc import get_stdout_stderr_paths -from aiet.utils.proc import read_process_info -from aiet.utils.proc import save_process_info -from aiet.utils.proc import terminate_command -from aiet.utils.proc import terminate_external_process - - -class SystemController: - """System controller class.""" - - def __init__(self) -> None: - """Create new instance of service controller.""" - self.cmd: Optional[sh.RunningCommand] = None - self.out_path: Optional[Path] = None - self.err_path: Optional[Path] = None - - def before_start(self) -> None: - """Run actions before system start.""" - - def after_start(self) -> None: - """Run actions after system start.""" - - def start(self, commands: List[str], cwd: Path) -> None: - """Start system.""" - if not isinstance(cwd, Path) or not cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(cwd)) - - if len(commands) != 1: - raise ConfigurationException("System should have only one command to run") - - startup_command = commands[0] - if not startup_command: - raise ConfigurationException("No startup command provided") - - self.before_start() - - self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) - - self.cmd = execute_command( - startup_command, - cwd, - bg=True, - out=str(self.out_path), - err=str(self.err_path), - ) - - self.after_start() - - def stop( - self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 - ) -> None: - """Stop system.""" - if self.cmd is not None and self.is_running(): - terminate_command(self.cmd, wait, wait_period, number_of_attempts) - - def is_running(self) -> bool: - """Check if underlying process is running.""" - return self.cmd is not None and self.cmd.is_alive() - - def get_output(self) -> Tuple[str, str]: - """Return application output.""" - if self.cmd is None or self.out_path is None or self.err_path is None: - return ("", "") - - return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) - - -class SystemControllerSingleInstance(SystemController): - """System controller with support of system's single instance.""" - - def __init__(self, pid_file_path: Optional[Path] = None) -> None: - """Create new instance of the service controller.""" - super().__init__() - self.pid_file_path = pid_file_path - - def before_start(self) -> None: - """Run actions before system start.""" - self._check_if_previous_instance_is_running() - - def after_start(self) -> None: - """Run actions after system start.""" - self._save_process_info() - - def _check_if_previous_instance_is_running(self) -> None: - """Check if another instance of the system is running.""" - process_info = read_process_info(self._pid_file()) - - for item in process_info: - try: - process = psutil.Process(item.pid) - same_process = ( - process.name() == item.name - and process.exe() == item.executable - and process.cwd() == item.cwd - ) - if same_process: - print( - "Stopping previous instance of the system [{}]".format(item.pid) - ) - terminate_external_process(process) - except psutil.NoSuchProcess: - pass - - def _save_process_info(self, wait_period: float = 2) -> None: - """Save information about system's processes.""" - if self.cmd is None or not self.is_running(): - return - - # give some time for the system to start - time.sleep(wait_period) - - save_process_info(self.cmd.process.pid, self._pid_file()) - - def _pid_file(self) -> Path: - """Return path to file which is used for saving process info.""" - if not self.pid_file_path: - raise Exception("No pid file path presented") - - return self.pid_file_path diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py deleted file mode 100644 index 1653ee2..0000000 --- a/src/aiet/backend/execution.py +++ /dev/null @@ -1,859 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application execution module.""" -import itertools -import json -import random -import re -import string -import sys -import time -import warnings -from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import ContextManager -from typing import Dict -from typing import Generator -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypedDict -from typing import Union - -from filelock import FileLock -from filelock import Timeout - -from aiet.backend.application import Application -from aiet.backend.application import get_application -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import Param -from aiet.backend.common import parse_raw_parameter -from aiet.backend.common import resolve_all_parameters -from aiet.backend.output_parser import Base64OutputParser -from aiet.backend.output_parser import OutputParser -from aiet.backend.output_parser import RegexOutputParser -from aiet.backend.system import ControlledSystem -from aiet.backend.system import get_system -from aiet.backend.system import StandaloneSystem -from aiet.backend.system import System -from aiet.backend.tool import get_tool -from aiet.backend.tool import Tool -from aiet.utils.fs import recreate_directory -from aiet.utils.fs import remove_directory -from aiet.utils.fs import valid_for_filename -from aiet.utils.proc import run_and_wait - - -class AnotherInstanceIsRunningException(Exception): - """Concurrent execution error.""" - - -class ConnectionException(Exception): - """Connection exception.""" - - -class ExecutionParams(TypedDict, total=False): - """Execution parameters.""" - - disable_locking: bool - unique_build_dir: bool - - -class ExecutionContext: - """Command execution context.""" - - # pylint: disable=too-many-arguments,too-many-instance-attributes - def __init__( - self, - app: Union[Application, Tool], - app_params: List[str], - system: Optional[System], - system_params: List[str], - custom_deploy_data: Optional[List[DataPaths]] = None, - execution_params: Optional[ExecutionParams] = None, - report_file: Optional[Path] = None, - ): - """Init execution context.""" - self.app = app - self.app_params = app_params - self.custom_deploy_data = custom_deploy_data or [] - self.system = system - self.system_params = system_params - self.execution_params = execution_params or ExecutionParams() - self.report_file = report_file - - self.reporter: Optional[Reporter] - if self.report_file: - # Create reporter with output parsers - parsers: List[OutputParser] = [] - if system and system.reporting: - # Add RegexOutputParser, if it is configured in the system - parsers.append(RegexOutputParser("system", system.reporting["regex"])) - # Add Base64 parser for applications - parsers.append(Base64OutputParser("application")) - self.reporter = Reporter(parsers=parsers) - else: - self.reporter = None # No reporter needed. - - self.param_resolver = ParamResolver(self) - self._resolved_build_dir: Optional[Path] = None - - @property - def is_deploy_needed(self) -> bool: - """Check if application requires data deployment.""" - if isinstance(self.app, Application): - return ( - len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0 - ) - return False - - @property - def is_locking_required(self) -> bool: - """Return true if any form of locking required.""" - return not self._disable_locking() and ( - self.app.lock or (self.system is not None and self.system.lock) - ) - - @property - def is_build_required(self) -> bool: - """Return true if application build required.""" - return "build" in self.app.commands - - @property - def is_unique_build_dir_required(self) -> bool: - """Return true if unique build dir required.""" - return self.execution_params.get("unique_build_dir", False) - - def build_dir(self) -> Path: - """Return resolved application build dir.""" - if self._resolved_build_dir is not None: - return self._resolved_build_dir - - if ( - not isinstance(self.app.config_location, Path) - or not self.app.config_location.is_dir() - ): - raise ConfigurationException( - "Application {} has wrong config location".format(self.app.name) - ) - - _build_dir = self.app.build_dir - if _build_dir: - _build_dir = resolve_all_parameters(_build_dir, self.param_resolver) - - if not _build_dir: - raise ConfigurationException( - "No build directory defined for the app {}".format(self.app.name) - ) - - if self.is_unique_build_dir_required: - random_suffix = "".join( - random.choices(string.ascii_lowercase + string.digits, k=7) - ) - _build_dir = "{}_{}".format(_build_dir, random_suffix) - - self._resolved_build_dir = self.app.config_location / _build_dir - return self._resolved_build_dir - - def _disable_locking(self) -> bool: - """Return true if locking should be disabled.""" - return self.execution_params.get("disable_locking", False) - - -class ParamResolver: - """Parameter resolver.""" - - def __init__(self, context: ExecutionContext): - """Init parameter resolver.""" - self.ctx = context - - @staticmethod - def resolve_user_params( - cmd_name: Optional[str], - index_or_alias: str, - resolved_params: Optional[List[Tuple[Optional[str], Param]]], - ) -> str: - """Resolve user params.""" - if not cmd_name or resolved_params is None: - raise ConfigurationException("Unable to resolve user params") - - param_value: Optional[str] = None - param: Optional[Param] = None - - if index_or_alias.isnumeric(): - i = int(index_or_alias) - if i not in range(len(resolved_params)): - raise ConfigurationException( - "Invalid index {} for user params of command {}".format(i, cmd_name) - ) - param_value, param = resolved_params[i] - else: - for val, par in resolved_params: - if par.alias == index_or_alias: - param_value, param = val, par - break - - if param is None: - raise ConfigurationException( - "No user parameter for command '{}' with alias '{}'.".format( - cmd_name, index_or_alias - ) - ) - - if param_value: - # We need to handle to cases of parameters here: - # 1) Optional parameters (non-positional with a name and value) - # 2) Positional parameters (value only, no name needed) - # Default to empty strings for positional arguments - param_name = "" - separator = "" - if param.name is not None: - # A valid param name means we have an optional/non-positional argument: - # The separator is an empty string in case the param_name - # has an equal sign as we have to honour it. - # If the parameter doesn't end with an equal sign then a - # space character is injected to split the parameter name - # and its value - param_name = param.name - separator = "" if param.name.endswith("=") else " " - - return "{param_name}{separator}{param_value}".format( - param_name=param_name, - separator=separator, - param_value=param_value, - ) - - if param.name is None: - raise ConfigurationException( - "Missing user parameter with alias '{}' for command '{}'.".format( - index_or_alias, cmd_name - ) - ) - - return param.name # flag: just return the parameter name - - def resolve_commands_and_params( - self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str - ) -> str: - """Resolve command or command's param value.""" - if backend_type == "system": - backend = cast(Backend, self.ctx.system) - backend_params = self.ctx.system_params - else: # Application or Tool backend - backend = cast(Backend, self.ctx.app) - backend_params = self.ctx.app_params - - if cmd_name not in backend.commands: - raise ConfigurationException("Command {} not found".format(cmd_name)) - - if return_params: - params = backend.resolved_parameters(cmd_name, backend_params) - if index_or_alias.isnumeric(): - i = int(index_or_alias) - if i not in range(len(params)): - raise ConfigurationException( - "Invalid parameter index {} for command {}".format(i, cmd_name) - ) - - param_value = params[i][0] - else: - param_value = None - for value, param in params: - if param.alias == index_or_alias: - param_value = value - break - - if not param_value: - raise ConfigurationException( - ( - "No value for parameter with index or alias {} of command {}" - ).format(index_or_alias, cmd_name) - ) - return param_value - - if not index_or_alias.isnumeric(): - raise ConfigurationException("Bad command index {}".format(index_or_alias)) - - i = int(index_or_alias) - commands = backend.build_command(cmd_name, backend_params, self.param_resolver) - if i not in range(len(commands)): - raise ConfigurationException( - "Invalid index {} for command {}".format(i, cmd_name) - ) - - return commands[i] - - def resolve_variables(self, backend_type: str, var_name: str) -> str: - """Resolve variable value.""" - if backend_type == "system": - backend = cast(Backend, self.ctx.system) - else: # Application or Tool backend - backend = cast(Backend, self.ctx.app) - - if var_name not in backend.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) - - return backend.variables[var_name] - - def param_matcher( - self, - param_name: str, - cmd_name: Optional[str], - resolved_params: Optional[List[Tuple[Optional[str], Param]]], - ) -> str: - """Regexp to resolve a param from the param_name.""" - # this pattern supports parameter names like "application.commands.run:0" and - # "system.commands.run.params:0" - # Note: 'software' is included for backward compatibility. - commands_and_params_match = re.match( - r"(?P<type>application|software|tool|system)[.]commands[.]" - r"(?P<name>\w+)" - r"(?P<params>[.]params|)[:]" - r"(?P<index_or_alias>\w+)", - param_name, - ) - - if commands_and_params_match: - backend_type, cmd_name, return_params, index_or_alias = ( - commands_and_params_match["type"], - commands_and_params_match["name"], - commands_and_params_match["params"], - commands_and_params_match["index_or_alias"], - ) - return self.resolve_commands_and_params( - backend_type, cmd_name, bool(return_params), index_or_alias - ) - - # Note: 'software' is included for backward compatibility. - variables_match = re.match( - r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)", - param_name, - ) - if variables_match: - backend_type, var_name = ( - variables_match["type"], - variables_match["var_name"], - ) - return self.resolve_variables(backend_type, var_name) - - user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name) - if user_params_match: - index_or_alias = user_params_match["index_or_alias"] - return self.resolve_user_params(cmd_name, index_or_alias, resolved_params) - - raise ConfigurationException( - "Unable to resolve parameter {}".format(param_name) - ) - - def param_resolver( - self, - param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, - ) -> str: - """Resolve parameter value based on current execution context.""" - # Note: 'software.*' is included for backward compatibility. - resolved_param = None - if param_name in ["application.name", "tool.name", "software.name"]: - resolved_param = self.ctx.app.name - elif param_name in [ - "application.description", - "tool.description", - "software.description", - ]: - resolved_param = self.ctx.app.description - elif self.ctx.app.config_location and ( - param_name - in ["application.config_dir", "tool.config_dir", "software.config_dir"] - ): - resolved_param = str(self.ctx.app.config_location.absolute()) - elif self.ctx.app.build_dir and ( - param_name - in ["application.build_dir", "tool.build_dir", "software.build_dir"] - ): - resolved_param = str(self.ctx.build_dir().absolute()) - elif self.ctx.system is not None: - if param_name == "system.name": - resolved_param = self.ctx.system.name - elif param_name == "system.description": - resolved_param = self.ctx.system.description - elif param_name == "system.config_dir" and self.ctx.system.config_location: - resolved_param = str(self.ctx.system.config_location.absolute()) - - if not resolved_param: - resolved_param = self.param_matcher(param_name, cmd_name, resolved_params) - return resolved_param - - def __call__( - self, - param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, - ) -> str: - """Resolve provided parameter.""" - return self.param_resolver(param_name, cmd_name, resolved_params) - - -class Reporter: - """Report metrics from the simulation output.""" - - def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None: - """Create an empty reporter (i.e. no parsers registered).""" - self.parsers: List[OutputParser] = parsers if parsers is not None else [] - self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) - - def parse(self, output: bytearray) -> None: - """Parse output and append parsed metrics to internal report dict.""" - for parser in self.parsers: - # Merge metrics from different parsers (do not overwrite) - self._report[parser.name]["metrics"].update(parser(output)) - - def get_filtered_output(self, output: bytearray) -> bytearray: - """Filter the output according to each parser.""" - for parser in self.parsers: - output = parser.filter_out_parsed_content(output) - return output - - def report(self, ctx: ExecutionContext) -> Dict[str, Any]: - """Add static simulation info to parsed data and return the report.""" - report: Dict[str, Any] = defaultdict(dict) - # Add static simulation info - report.update(self._static_info(ctx)) - # Add metrics parsed from the output - for key, val in self._report.items(): - report[key].update(val) - return report - - @staticmethod - def save(report: Dict[str, Any], report_file: Path) -> None: - """Save the report to a JSON file.""" - with open(report_file, "w", encoding="utf-8") as file: - json.dump(report, file, indent=4) - - @staticmethod - def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: - """ - Build a dict of all parameters, {name:value}. - - Param values taken from command line if specified, defaults otherwise. - """ - # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} - app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) - - # a map of params declared in the application, with values taken from the CLI, - # defaults otherwise - all_params = { - (p.alias or p.name): app_params_map.get( - cast(str, p.name), cast(str, p.default_value) - ) - for cmd in backend.commands.values() - for p in cmd.params - } - return cast(Dict[str, str], all_params) - - @staticmethod - def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: - """Extract static simulation information from the context.""" - if ctx.system is None: - raise ValueError("No system available to report.") - - info = { - "system": { - "name": ctx.system.name, - "params": Reporter._compute_all_params(ctx.system_params, ctx.system), - }, - "application": { - "name": ctx.app.name, - "params": Reporter._compute_all_params(ctx.app_params, ctx.app), - }, - } - return info - - -def validate_parameters( - backend: Backend, command_names: List[str], params: List[str] -) -> None: - """Check parameters passed to backend.""" - for param in params: - acceptable = any( - backend.validate_parameter(command_name, param) - for command_name in command_names - if command_name in backend.commands - ) - - if not acceptable: - backend_type = "System" if isinstance(backend, System) else "Application" - raise ValueError( - "{} parameter '{}' not valid for command '{}'".format( - backend_type, param, " or ".join(command_names) - ) - ) - - -def get_application_by_name_and_system( - application_name: str, system_name: str -) -> Application: - """Get application.""" - applications = get_application(application_name, system_name) - if not applications: - raise ValueError( - "Application '{}' doesn't support the system '{}'".format( - application_name, system_name - ) - ) - - if len(applications) != 1: - raise ValueError( - "Error during getting application {} for the system {}".format( - application_name, system_name - ) - ) - - return applications[0] - - -def get_application_and_system( - application_name: str, system_name: str -) -> Tuple[Application, System]: - """Return application and system by provided names.""" - system = get_system(system_name) - if not system: - raise ValueError("System {} is not found".format(system_name)) - - application = get_application_by_name_and_system(application_name, system_name) - - return application, system - - -def execute_application_command( # pylint: disable=too-many-arguments - command_name: str, - application_name: str, - application_params: List[str], - system_name: str, - system_params: List[str], - custom_deploy_data: List[DataPaths], -) -> None: - """Execute application command. - - .. deprecated:: 21.12 - """ - warnings.warn( - "Use 'run_application()' instead. Use of 'execute_application_command()' is " - "deprecated and might be removed in a future release.", - DeprecationWarning, - ) - - if command_name not in ["build", "run"]: - raise ConfigurationException("Unsupported command {}".format(command_name)) - - application, system = get_application_and_system(application_name, system_name) - validate_parameters(application, [command_name], application_params) - validate_parameters(system, [command_name], system_params) - - ctx = ExecutionContext( - app=application, - app_params=application_params, - system=system, - system_params=system_params, - custom_deploy_data=custom_deploy_data, - ) - - if command_name == "run": - execute_application_command_run(ctx) - else: - execute_application_command_build(ctx) - - -# pylint: disable=too-many-arguments -def run_application( - application_name: str, - application_params: List[str], - system_name: str, - system_params: List[str], - custom_deploy_data: List[DataPaths], - report_file: Optional[Path] = None, -) -> None: - """Run application on the provided system.""" - application, system = get_application_and_system(application_name, system_name) - validate_parameters(application, ["build", "run"], application_params) - validate_parameters(system, ["build", "run"], system_params) - - execution_params = ExecutionParams() - if isinstance(system, StandaloneSystem): - execution_params["disable_locking"] = True - execution_params["unique_build_dir"] = True - - ctx = ExecutionContext( - app=application, - app_params=application_params, - system=system, - system_params=system_params, - custom_deploy_data=custom_deploy_data, - execution_params=execution_params, - report_file=report_file, - ) - - with build_dir_manager(ctx): - if ctx.is_build_required: - execute_application_command_build(ctx) - - execute_application_command_run(ctx) - - -def execute_application_command_build(ctx: ExecutionContext) -> None: - """Execute application command 'build'.""" - with ExitStack() as context_stack: - for manager in get_context_managers("build", ctx): - context_stack.enter_context(manager(ctx)) - - build_dir = ctx.build_dir() - recreate_directory(build_dir) - - build_commands = ctx.app.build_command( - "build", ctx.app_params, ctx.param_resolver - ) - execute_commands_locally(build_commands, build_dir) - - -def execute_commands_locally(commands: List[str], cwd: Path) -> None: - """Execute list of commands locally.""" - for command in commands: - print("Running: {}".format(command)) - run_and_wait( - command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr - ) - - -def execute_application_command_run(ctx: ExecutionContext) -> None: - """Execute application command.""" - assert ctx.system is not None, "System must be provided." - if ctx.is_deploy_needed and not ctx.system.supports_deploy: - raise ConfigurationException( - "System {} does not support data deploy".format(ctx.system.name) - ) - - with ExitStack() as context_stack: - for manager in get_context_managers("run", ctx): - context_stack.enter_context(manager(ctx)) - - print("Generating commands to execute") - commands_to_run = build_run_commands(ctx) - - if ctx.system.connectable: - establish_connection(ctx) - - if ctx.system.supports_deploy: - deploy_data(ctx) - - for command in commands_to_run: - print("Running: {}".format(command)) - exit_code, std_output, std_err = ctx.system.run(command) - - if exit_code != 0: - print("Application exited with exit code {}".format(exit_code)) - - if ctx.reporter: - ctx.reporter.parse(std_output) - std_output = ctx.reporter.get_filtered_output(std_output) - - print(std_output.decode("utf8"), end="") - print(std_err.decode("utf8"), end="") - - if ctx.reporter: - report = ctx.reporter.report(ctx) - ctx.reporter.save(report, cast(Path, ctx.report_file)) - - -def establish_connection( - ctx: ExecutionContext, retries: int = 90, interval: float = 15.0 -) -> None: - """Establish connection with the system.""" - assert ctx.system is not None, "System is required." - host, port = ctx.system.connection_details() - print( - "Trying to establish connection with '{}:{}' - " - "{} retries every {} seconds ".format(host, port, retries, interval), - end="", - ) - - try: - for _ in range(retries): - print(".", end="", flush=True) - - if ctx.system.establish_connection(): - break - - if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running(): - print( - "\n\n---------- {} execution failed ----------".format( - ctx.system.name - ) - ) - stdout, stderr = ctx.system.get_output() - print(stdout) - print(stderr) - - raise Exception("System is not running") - - wait(interval) - else: - raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port)) - finally: - print() - - -def wait(interval: float) -> None: - """Wait for a period of time.""" - time.sleep(interval) - - -def deploy_data(ctx: ExecutionContext) -> None: - """Deploy data to the system.""" - if isinstance(ctx.app, Application): - # Only application can deploy data (tools can not) - assert ctx.system is not None, "System is required." - for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data): - print("Deploying {} onto {}".format(item.src, item.dst)) - ctx.system.deploy(item.src, item.dst) - - -def build_run_commands(ctx: ExecutionContext) -> List[str]: - """Build commands to run application.""" - if isinstance(ctx.system, StandaloneSystem): - return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver) - - return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver) - - -@contextmanager -def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Context manager used for system initialisation before run.""" - system = cast(ControlledSystem, ctx.system) - commands = system.build_command("run", ctx.system_params, ctx.param_resolver) - pid_file_path: Optional[Path] = None - if ctx.is_locking_required: - file_lock_path = get_file_lock_path(ctx) - pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem) - - system.start(commands, ctx.is_locking_required, pid_file_path) - try: - yield - finally: - print("Shutting down sequence...") - print("Stopping {}... (It could take few seconds)".format(system.name)) - system.stop(wait=True) - print("{} stopped successfully.".format(system.name)) - - -@contextmanager -def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Lock execution manager.""" - file_lock_path = get_file_lock_path(ctx) - file_lock = FileLock(str(file_lock_path)) - - try: - file_lock.acquire(timeout=1) - except Timeout as error: - raise AnotherInstanceIsRunningException() from error - - try: - yield - finally: - file_lock.release() - - -def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path: - """Get file lock path.""" - lock_modules = [] - if ctx.app.lock: - lock_modules.append(ctx.app.name) - if ctx.system is not None and ctx.system.lock: - lock_modules.append(ctx.system.name) - lock_filename = "" - if lock_modules: - lock_filename = "_".join(["middleware"] + lock_modules) + ".lock" - - if lock_filename: - lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver) - lock_filename = valid_for_filename(lock_filename) - - if not lock_filename: - raise ConfigurationException("No filename for lock provided") - - if not isinstance(lock_dir, Path) or not lock_dir.is_dir(): - raise ConfigurationException( - "Invalid directory {} for lock files provided".format(lock_dir) - ) - - return lock_dir / lock_filename - - -@contextmanager -def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Build directory manager.""" - try: - yield - finally: - if ( - ctx.is_build_required - and ctx.is_unique_build_dir_required - and ctx.build_dir().is_dir() - ): - remove_directory(ctx.build_dir()) - - -def get_context_managers( - command_name: str, ctx: ExecutionContext -) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]: - """Get context manager for the system.""" - managers = [] - - if ctx.is_locking_required: - managers.append(lock_execution_manager) - - if command_name == "run": - if isinstance(ctx.system, ControlledSystem): - managers.append(controlled_system_manager) - - return managers - - -def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool: - """Return tool (optionally by provided system name.""" - tools = get_tool(tool_name, system_name) - if not tools: - raise ConfigurationException( - "Tool '{}' not found or doesn't support the system '{}'".format( - tool_name, system_name - ) - ) - if len(tools) != 1: - raise ConfigurationException( - "Please specify the system for tool {}.".format(tool_name) - ) - tool = tools[0] - - return tool - - -def execute_tool_command( - tool_name: str, - tool_params: List[str], - system_name: Optional[str] = None, -) -> None: - """Execute the tool command locally calling the 'run' command.""" - tool = get_tool_by_system(tool_name, system_name) - ctx = ExecutionContext( - app=tool, app_params=tool_params, system=None, system_params=[] - ) - commands = tool.build_command("run", tool_params, ctx.param_resolver) - - execute_commands_locally(commands, Path.cwd()) diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py deleted file mode 100644 index 111772a..0000000 --- a/src/aiet/backend/output_parser.py +++ /dev/null @@ -1,176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Definition of output parsers (including base class OutputParser).""" -import base64 -import json -import re -from abc import ABC -from abc import abstractmethod -from typing import Any -from typing import Dict -from typing import Union - - -class OutputParser(ABC): - """Abstract base class for output parsers.""" - - def __init__(self, name: str) -> None: - """Set up the name of the parser.""" - super().__init__() - self.name = name - - @abstractmethod - def __call__(self, output: bytearray) -> Dict[str, Any]: - """Parse the output and return a map of names to metrics.""" - return {} - - # pylint: disable=no-self-use - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """ - Filter out the parsed content from the output. - - Does nothing by default. Can be overridden in subclasses. - """ - return output - - -class RegexOutputParser(OutputParser): - """Parser of standard output data using regular expressions.""" - - _TYPE_MAP = {"str": str, "float": float, "int": int} - - def __init__( - self, - name: str, - regex_config: Dict[str, Dict[str, str]], - ) -> None: - """ - Set up the parser with the regular expressions. - - The regex_config is mapping from a name to a dict with keys 'pattern' - and 'type': - - The 'pattern' holds the regular expression that must contain exactly - one capturing parenthesis - - The 'type' can be one of ['str', 'float', 'int']. - - Example: - ``` - {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}} - ``` - - The different regular expressions from the config are combined using - non-capturing parenthesis, i.e. regular expressions must not overlap - if more than one match per line is expected. - """ - super().__init__(name) - - self._verify_config(regex_config) - self._regex_cfg = regex_config - - # Compile regular expression to match in the output - self._regex = re.compile( - "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values()) - ) - - def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]: - """ - Parse the output and return a map of names to metrics. - - Example: - Assuming a regex_config as used as example in `__init__()` and the - following output: - ``` - Simulation finished: - SIMULATION_STATUS = SUCCESS - Simulation DONE - ``` - Then calling the parser should return the following dict: - ``` - { - "Metric1": "SUCCESS" - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for line_result in results: - for idx, (name, cfg) in enumerate(self._regex_cfg.items()): - # The result(s) returned by findall() are either a single string - # or a tuple (depending on the number of groups etc.) - result = ( - line_result if isinstance(line_result, str) else line_result[idx] - ) - if result: - mapped_result = self._TYPE_MAP[cfg["type"]](result) - metrics[name] = mapped_result - return metrics - - def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None: - """Make sure we have a valid regex_config. - - I.e. - - Exactly one capturing parenthesis per pattern - - Correct types - """ - for name, cfg in regex_config.items(): - # Check that there is one capturing group defined in the pattern. - regex = re.compile(cfg["pattern"]) - if regex.groups != 1: - raise ValueError( - f"Pattern for metric '{name}' must have exactly one " - f"capturing parenthesis, but it has {regex.groups}." - ) - # Check if type is supported - if not cfg["type"] in self._TYPE_MAP: - raise TypeError( - f"Type '{cfg['type']}' for metric '{name}' is not " - f"supported. Choose from: {list(self._TYPE_MAP.keys())}." - ) - - -class Base64OutputParser(OutputParser): - """ - Parser to extract base64-encoded JSON from tagged standard output. - - Example of the tagged output: - ``` - # Encoded JSON: {"test": 1234} - <metrics>eyJ0ZXN0IjogMTIzNH0</metrics> - ``` - """ - - TAG_NAME = "metrics" - - def __init__(self, name: str) -> None: - """Set up the regular expression to extract tagged strings.""" - super().__init__(name) - self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>") - - def __call__(self, output: bytearray) -> Dict[str, Any]: - """ - Parse the output and return a map of index (as string) to decoded JSON. - - Example: - Using the tagged output from the class docs the parser should return - the following dict: - ``` - { - "0": {"test": 1234} - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for idx, result_base64 in enumerate(results): - result_json = base64.b64decode(result_base64, validate=True) - result = json.loads(result_json) - metrics[str(idx)] = result - - return metrics - - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """Filter out base64-encoded content from the output.""" - output_str = self._regex.sub("", output.decode("utf-8")) - return bytearray(output_str.encode("utf-8")) diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py deleted file mode 100644 index c621436..0000000 --- a/src/aiet/backend/protocol.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain protocol related classes and functions.""" -from abc import ABC -from abc import abstractmethod -from contextlib import closing -from pathlib import Path -from typing import Any -from typing import cast -from typing import Iterable -from typing import Optional -from typing import Tuple -from typing import Union - -import paramiko - -from aiet.backend.common import ConfigurationException -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.config import SSHConfig -from aiet.utils.proc import run_and_wait - - -# Redirect all paramiko thread exceptions to a file otherwise these will be -# printed to stderr. -paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) - - -class SSHConnectionException(Exception): - """SSH connection exception.""" - - -class SupportsClose(ABC): - """Class indicates support of close operation.""" - - @abstractmethod - def close(self) -> None: - """Close protocol session.""" - - -class SupportsDeploy(ABC): - """Class indicates support of deploy operation.""" - - @abstractmethod - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Abstract method for deploy data.""" - - -class SupportsConnection(ABC): - """Class indicates that protocol uses network connections.""" - - @abstractmethod - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - - @abstractmethod - def connection_details(self) -> Tuple[str, int]: - """Return connection details (host, port).""" - - -class Protocol(ABC): - """Abstract class for representing the protocol.""" - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - self.__dict__.update(iterable, **kwargs) - self._validate() - - @abstractmethod - def _validate(self) -> None: - """Abstract method for config data validation.""" - - @abstractmethod - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Abstract method for running commands. - - Returns a tuple: (exit_code, stdout, stderr) - """ - - -class CustomSFTPClient(paramiko.SFTPClient): - """Class for creating a custom sftp client.""" - - def put_dir(self, source: Path, target: str) -> None: - """Upload the source directory to the target path. - - The target directory needs to exists and the last directory of the - source will be created under the target with all its content. - """ - # Create the target directory - self._mkdir(target, ignore_existing=True) - # Create the last directory in the source on the target - self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) - # Go through the whole content of source - for item in sorted(source.glob("**/*")): - relative_path = item.relative_to(source.parent) - remote_target = target / relative_path - if item.is_file(): - self.put(str(item), str(remote_target)) - else: - self._mkdir(str(remote_target), ignore_existing=True) - - def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: - """Extend mkdir functionality. - - This version adds an option to not fail if the folder exists. - """ - try: - super().mkdir(path, mode) - except IOError as error: - if ignore_existing: - pass - else: - raise error - - -class LocalProtocol(Protocol): - """Class for local protocol.""" - - protocol: str - cwd: Path - - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Run command locally. - - Returns a tuple: (exit_code, stdout, stderr) - """ - if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(self.cwd)) - - stdout = bytearray() - stderr = bytearray() - - return run_and_wait( - command, self.cwd, terminate_on_error=True, out=stdout, err=stderr - ) - - def _validate(self) -> None: - """Validate protocol configuration.""" - assert hasattr(self, "protocol") and self.protocol == "local" - assert hasattr(self, "cwd") - - -class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): - """Class for SSH protocol.""" - - protocol: str - username: str - password: str - hostname: str - port: int - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - super().__init__(iterable, **kwargs) - # Internal state to store if the system is connectable. It will be set - # to true at the first connection instance - self.client: Optional[paramiko.client.SSHClient] = None - self.port = int(self.port) - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command over SSH. - - Returns a tuple: (exit_code, stdout, stderr) - """ - transport = self._get_transport() - with closing(transport.open_session()) as channel: - # Enable shell's .profile settings and execute command - channel.exec_command("bash -l -c '{}'".format(command)) - exit_status = -1 - stdout = bytearray() - stderr = bytearray() - while True: - if channel.exit_status_ready(): - exit_status = channel.recv_exit_status() - # Call it one last time to read any leftover in the channel - self._recv_stdout_err(channel, stdout, stderr) - break - self._recv_stdout_err(channel, stdout, stderr) - - return exit_status, stdout, stderr - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy src to remote dst over SSH. - - src and dst should be path to a file or directory. - """ - transport = self._get_transport() - sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) - - with closing(sftp): - if src.is_dir(): - sftp.put_dir(src, dst) - elif src.is_file(): - sftp.put(str(src), dst) - else: - raise Exception("Deploy error: file type not supported") - - # After the deployment of files, sync the remote filesystem to flush - # buffers to hard disk - self.run("sync") - - def close(self) -> None: - """Close protocol session.""" - if self.client is not None: - print("Try syncing remote file system...") - # Before stopping the system, we try to run sync to make sure all - # data are flushed on disk. - self.run("sync", retry=False) - self._close_client(self.client) - - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - if self.client is not None: - return True - - self.client = self._connect() - return self.client is not None - - def _get_transport(self) -> paramiko.transport.Transport: - """Get transport.""" - self.establish_connection() - - if self.client is None: - raise SSHConnectionException( - "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) - ) - - transport = self.client.get_transport() - if not transport: - raise Exception("Unable to get transport") - - return transport - - def connection_details(self) -> Tuple[str, int]: - """Return connection details of underlying system.""" - return (self.hostname, self.port) - - def _connect(self) -> Optional[paramiko.client.SSHClient]: - """Try to establish connection.""" - client: Optional[paramiko.client.SSHClient] = None - try: - client = paramiko.client.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - self.hostname, - self.port, - self.username, - self.password, - # next parameters should be set to False to disable authentication - # using ssh keys - allow_agent=False, - look_for_keys=False, - ) - return client - except ( - # OSError raised on first attempt to connect when running inside Docker - OSError, - paramiko.ssh_exception.NoValidConnectionsError, - paramiko.ssh_exception.SSHException, - ): - # even if connection is not established socket could be still - # open, it should be closed - self._close_client(client) - - return None - - @staticmethod - def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: - """Close ssh client.""" - try: - if client is not None: - client.close() - except Exception: # pylint: disable=broad-except - pass - - @classmethod - def _recv_stdout_err( - cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray - ) -> None: - """Read from channel to stdout/stder.""" - chunk_size = 512 - if channel.recv_ready(): - stdout_chunk = channel.recv(chunk_size) - stdout.extend(stdout_chunk) - if channel.recv_stderr_ready(): - stderr_chunk = channel.recv_stderr(chunk_size) - stderr.extend(stderr_chunk) - - def _validate(self) -> None: - """Check if there are all the info for establishing the connection.""" - assert hasattr(self, "protocol") and self.protocol == "ssh" - assert hasattr(self, "username") - assert hasattr(self, "password") - assert hasattr(self, "hostname") - assert hasattr(self, "port") - - -class ProtocolFactory: - """Factory class to return the appropriate Protocol class.""" - - @staticmethod - def get_protocol( - config: Optional[Union[SSHConfig, LocalProtocolConfig]], - **kwargs: Union[str, Path, None] - ) -> Union[SSHProtocol, LocalProtocol]: - """Return the right protocol instance based on the config.""" - if not config: - raise ValueError("No protocol config provided") - - protocol = config["protocol"] - if protocol == "ssh": - return SSHProtocol(config) - - if protocol == "local": - cwd = kwargs.get("cwd") - return LocalProtocol(config, cwd=cwd) - - raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py deleted file mode 100644 index dec175a..0000000 --- a/src/aiet/backend/source.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain source related classes and functions.""" -import os -import shutil -import tarfile -from abc import ABC -from abc import abstractmethod -from pathlib import Path -from tarfile import TarFile -from typing import Optional -from typing import Union - -from aiet.backend.common import AIET_CONFIG_FILE -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_config -from aiet.backend.common import is_backend_directory -from aiet.backend.common import load_config -from aiet.backend.config import BackendConfig -from aiet.utils.fs import copy_directory_content - - -class Source(ABC): - """Source class.""" - - @abstractmethod - def name(self) -> Optional[str]: - """Get source name.""" - - @abstractmethod - def config(self) -> Optional[BackendConfig]: - """Get configuration file content.""" - - @abstractmethod - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - - @abstractmethod - def create_destination(self) -> bool: - """Return True if destination folder should be created before installation.""" - - -class DirectorySource(Source): - """DirectorySource class.""" - - def __init__(self, directory_path: Path) -> None: - """Create the DirectorySource instance.""" - assert isinstance(directory_path, Path) - self.directory_path = directory_path - - def name(self) -> str: - """Return name of source.""" - return self.directory_path.name - - def config(self) -> Optional[BackendConfig]: - """Return configuration file content.""" - if not is_backend_directory(self.directory_path): - raise ConfigurationException("No configuration file found") - - config_file = get_backend_config(self.directory_path) - return load_config(config_file) - - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) - - if not self.directory_path.is_dir(): - raise ConfigurationException( - "Directory {} does not exist".format(self.directory_path) - ) - - copy_directory_content(self.directory_path, destination) - - def create_destination(self) -> bool: - """Return True if destination folder should be created before installation.""" - return True - - -class TarArchiveSource(Source): - """TarArchiveSource class.""" - - def __init__(self, archive_path: Path) -> None: - """Create the TarArchiveSource class.""" - assert isinstance(archive_path, Path) - self.archive_path = archive_path - self._config: Optional[BackendConfig] = None - self._has_top_level_folder: Optional[bool] = None - self._name: Optional[str] = None - - def _read_archive_content(self) -> None: - """Read various information about archive.""" - # get source name from archive name (everything without extensions) - extensions = "".join(self.archive_path.suffixes) - self._name = self.archive_path.name.rstrip(extensions) - - if not self.archive_path.exists(): - return - - with self._open(self.archive_path) as archive: - try: - config_entry = archive.getmember(AIET_CONFIG_FILE) - self._has_top_level_folder = False - except KeyError as error_no_config: - try: - archive_entries = archive.getnames() - entries_common_prefix = os.path.commonprefix(archive_entries) - top_level_dir = entries_common_prefix.rstrip("/") - - if not top_level_dir: - raise RuntimeError( - "Archive has no top level directory" - ) from error_no_config - - config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE) - - config_entry = archive.getmember(config_path) - self._has_top_level_folder = True - self._name = top_level_dir - except (KeyError, RuntimeError) as error_no_root_dir_or_config: - raise ConfigurationException( - "No configuration file found" - ) from error_no_root_dir_or_config - - content = archive.extractfile(config_entry) - self._config = load_config(content) - - def config(self) -> Optional[BackendConfig]: - """Return configuration file content.""" - if self._config is None: - self._read_archive_content() - - return self._config - - def name(self) -> Optional[str]: - """Return name of the source.""" - if self._name is None: - self._read_archive_content() - - return self._name - - def create_destination(self) -> bool: - """Return True if destination folder must be created before installation.""" - if self._has_top_level_folder is None: - self._read_archive_content() - - return not self._has_top_level_folder - - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) - - with self._open(self.archive_path) as archive: - archive.extractall(destination) - - def _open(self, archive_path: Path) -> TarFile: - """Open archive file.""" - if not archive_path.is_file(): - raise ConfigurationException("File {} does not exist".format(archive_path)) - - if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): - mode = "r:gz" - else: - raise ConfigurationException( - "Unsupported archive type {}".format(archive_path) - ) - - # The returned TarFile object can be used as a context manager (using - # 'with') by the calling instance. - return tarfile.open( # pylint: disable=consider-using-with - self.archive_path, mode=mode - ) - - -def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: - """Return appropriate source instance based on provided source path.""" - if source_path.is_file(): - return TarArchiveSource(source_path) - - if source_path.is_dir(): - return DirectorySource(source_path) - - raise ConfigurationException("Unable to read {}".format(source_path)) - - -def create_destination_and_install(source: Source, resource_path: Path) -> None: - """Create destination directory and install source. - - This function is used for actual installation of system/backend New - directory will be created inside :resource_path: if needed If for example - archive contains top level folder then no need to create new directory - """ - destination = resource_path - create_destination = source.create_destination() - - if create_destination: - name = source.name() - if not name: - raise ConfigurationException("Unable to get source name") - - destination = resource_path / name - destination.mkdir() - try: - source.install_into(destination) - except Exception as error: - if create_destination: - shutil.rmtree(destination) - raise error diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py deleted file mode 100644 index 48f1bb1..0000000 --- a/src/aiet/backend/system.py +++ /dev/null @@ -1,289 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""System backend module.""" -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_config -from aiet.backend.common import remove_backend -from aiet.backend.config import SystemConfig -from aiet.backend.controller import SystemController -from aiet.backend.controller import SystemControllerSingleInstance -from aiet.backend.protocol import ProtocolFactory -from aiet.backend.protocol import SupportsClose -from aiet.backend.protocol import SupportsConnection -from aiet.backend.protocol import SupportsDeploy -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import get_source -from aiet.utils.fs import get_resources - - -def get_available_systems_directory_names() -> List[str]: - """Return a list of directory names for all avialable systems.""" - return [entry.name for entry in get_backend_directories("systems")] - - -def get_available_systems() -> List["System"]: - """Return a list with all available systems.""" - available_systems = [] - for config_json in get_backend_configs("systems"): - config_entries = cast(List[SystemConfig], (load_config(config_json))) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - system = load_system(config_entry) - available_systems.append(system) - - return sorted(available_systems, key=lambda system: system.name) - - -def get_system(system_name: str) -> Optional["System"]: - """Return a system instance with the same name passed as argument.""" - available_systems = get_available_systems() - for system in available_systems: - if system_name == system.name: - return system - return None - - -def install_system(source_path: Path) -> None: - """Install new system.""" - try: - source = get_source(source_path) - config = cast(List[SystemConfig], source.config()) - systems_to_install = [load_system(entry) for entry in config] - except Exception as error: - raise ConfigurationException("Unable to read system definition") from error - - if not systems_to_install: - raise ConfigurationException("No system definition found") - - available_systems = get_available_systems() - already_installed = [s for s in systems_to_install if s in available_systems] - if already_installed: - names = [system.name for system in already_installed] - raise ConfigurationException( - "Systems [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_resources("systems")) - - -def remove_system(directory_name: str) -> None: - """Remove system.""" - remove_backend(directory_name, "systems") - - -class System(Backend): - """System class.""" - - def __init__(self, config: SystemConfig) -> None: - """Construct the System class using the dictionary passed.""" - super().__init__(config) - - self._setup_data_transfer(config) - self._setup_reporting(config) - - def _setup_data_transfer(self, config: SystemConfig) -> None: - data_transfer_config = config.get("data_transfer") - protocol = ProtocolFactory().get_protocol( - data_transfer_config, cwd=self.config_location - ) - self.protocol = protocol - - def _setup_reporting(self, config: SystemConfig) -> None: - self.reporting = config.get("reporting") - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command on the system. - - Returns a tuple: (exit_code, stdout, stderr) - """ - return self.protocol.run(command, retry) - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy files to the system.""" - if isinstance(self.protocol, SupportsDeploy): - self.protocol.deploy(src, dst, retry) - - @property - def supports_deploy(self) -> bool: - """Check if protocol supports deploy operation.""" - return isinstance(self.protocol, SupportsDeploy) - - @property - def connectable(self) -> bool: - """Check if protocol supports connection.""" - return isinstance(self.protocol, SupportsConnection) - - def establish_connection(self) -> bool: - """Establish connection with the system.""" - if not isinstance(self.protocol, SupportsConnection): - raise ConfigurationException( - "System {} does not support connections".format(self.name) - ) - - return self.protocol.establish_connection() - - def connection_details(self) -> Tuple[str, int]: - """Return connection details.""" - if not isinstance(self.protocol, SupportsConnection): - raise ConfigurationException( - "System {} does not support connections".format(self.name) - ) - - return self.protocol.connection_details() - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, System): - return False - - return super().__eq__(other) and self.name == other.name - - def get_details(self) -> Dict[str, Any]: - """Return a dictionary with all relevant information of a System.""" - output = { - "type": "system", - "name": self.name, - "description": self.description, - "data_transfer_protocol": self.protocol.protocol, - "commands": self._get_command_details(), - "annotations": self.annotations, - } - - return output - - -class StandaloneSystem(System): - """StandaloneSystem class.""" - - -def get_controller( - single_instance: bool, pid_file_path: Optional[Path] = None -) -> SystemController: - """Get system controller.""" - if single_instance: - return SystemControllerSingleInstance(pid_file_path) - - return SystemController() - - -class ControlledSystem(System): - """ControlledSystem class.""" - - def __init__(self, config: SystemConfig): - """Construct the ControlledSystem class using the dictionary passed.""" - super().__init__(config) - self.controller: Optional[SystemController] = None - - def start( - self, - commands: List[str], - single_instance: bool = True, - pid_file_path: Optional[Path] = None, - ) -> None: - """Launch the system.""" - if ( - not isinstance(self.config_location, Path) - or not self.config_location.is_dir() - ): - raise ConfigurationException( - "System {} has wrong config location".format(self.name) - ) - - self.controller = get_controller(single_instance, pid_file_path) - self.controller.start(commands, self.config_location) - - def is_running(self) -> bool: - """Check if system is running.""" - if not self.controller: - return False - - return self.controller.is_running() - - def get_output(self) -> Tuple[str, str]: - """Return system output.""" - if not self.controller: - return "", "" - - return self.controller.get_output() - - def stop(self, wait: bool = False) -> None: - """Stop the system.""" - if not self.controller: - raise Exception("System has not been started") - - if isinstance(self.protocol, SupportsClose): - try: - self.protocol.close() - except Exception as error: # pylint: disable=broad-except - print(error) - self.controller.stop(wait) - - -def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: - """Load system based on it's execution type.""" - data_transfer = config.get("data_transfer", {}) - protocol = data_transfer.get("protocol") - populate_shared_params(config) - - if protocol == "ssh": - return ControlledSystem(config) - - if protocol == "local": - return StandaloneSystem(config) - - raise ConfigurationException( - "Unsupported execution type for protocol {}".format(protocol) - ) - - -def populate_shared_params(config: SystemConfig) -> None: - """Populate command parameters with shared parameters.""" - user_params = config.get("user_params") - if not user_params or "shared" not in user_params: - return - - shared_user_params = user_params["shared"] - if not shared_user_params: - return - - only_aliases = all(p.get("alias") for p in shared_user_params) - if not only_aliases: - raise ConfigurationException("All shared parameters should have aliases") - - commands = config.get("commands", {}) - for cmd_name in ["build", "run"]: - command = commands.get(cmd_name) - if command is None: - commands[cmd_name] = [] - cmd_user_params = user_params.get(cmd_name) - if not cmd_user_params: - cmd_user_params = shared_user_params - else: - only_aliases = all(p.get("alias") for p in cmd_user_params) - if not only_aliases: - raise ConfigurationException( - "All parameters for command {} should have aliases".format(cmd_name) - ) - merged_by_alias = { - **{p.get("alias"): p for p in shared_user_params}, - **{p.get("alias"): p for p in cmd_user_params}, - } - cmd_user_params = list(merged_by_alias.values()) - - user_params[cmd_name] = cmd_user_params - - config["commands"] = commands - del user_params["shared"] diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py deleted file mode 100644 index d643665..0000000 --- a/src/aiet/backend/tool.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tool backend module.""" -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_application_or_tool_configs -from aiet.backend.common import load_config -from aiet.backend.config import ExtendedToolConfig -from aiet.backend.config import ToolConfig - - -def get_available_tool_directory_names() -> List[str]: - """Return a list of directory names for all available tools.""" - return [entry.name for entry in get_backend_directories("tools")] - - -def get_available_tools() -> List["Tool"]: - """Return a list with all available tools.""" - available_tools = [] - for config_json in get_backend_configs("tools"): - config_entries = cast(List[ExtendedToolConfig], load_config(config_json)) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - tools = load_tools(config_entry) - available_tools += tools - - return sorted(available_tools, key=lambda tool: tool.name) - - -def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]: - """Return a tool instance with the same name passed as argument.""" - return [ - tool - for tool in get_available_tools() - if tool.name == tool_name and (not system_name or tool.can_run_on(system_name)) - ] - - -def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]: - """Extract a list of unique tool names of all tools available.""" - return list( - set( - tool.name - for tool in get_available_tools() - if not system_name or tool.can_run_on(system_name) - ) - ) - - -class Tool(Backend): - """Class for representing a single tool component.""" - - def __init__(self, config: ToolConfig) -> None: - """Construct a Tool instance from a dict.""" - super().__init__(config) - - self.supported_systems = config.get("supported_systems", []) - - if "run" not in self.commands: - raise ConfigurationException("A Tool must have a 'run' command.") - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Tool): - return False - - return ( - super().__eq__(other) - and self.name == other.name - and set(self.supported_systems) == set(other.supported_systems) - ) - - def can_run_on(self, system_name: str) -> bool: - """Check if the tool can run on the system passed as argument.""" - return system_name in self.supported_systems - - def get_details(self) -> Dict[str, Any]: - """Return dictionary with all relevant information of the Tool instance.""" - output = { - "type": "tool", - "name": self.name, - "description": self.description, - "supported_systems": self.supported_systems, - "commands": self._get_command_details(), - } - - return output - - -def load_tools(config: ExtendedToolConfig) -> List[Tool]: - """Load tool. - - Tool configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - Tool instance with appropriate configuration. - """ - configs = load_application_or_tool_configs( - config, ToolConfig, is_system_required=False - ) - tools = [Tool(cfg) for cfg in configs] - return tools |