diff options
Diffstat (limited to 'src/aiet/backend/system.py')
-rw-r--r-- | src/aiet/backend/system.py | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py new file mode 100644 index 0000000..48f1bb1 --- /dev/null +++ b/src/aiet/backend/system.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""System backend module.""" +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from aiet.backend.common import Backend +from aiet.backend.common import ConfigurationException +from aiet.backend.common import get_backend_configs +from aiet.backend.common import get_backend_directories +from aiet.backend.common import load_config +from aiet.backend.common import remove_backend +from aiet.backend.config import SystemConfig +from aiet.backend.controller import SystemController +from aiet.backend.controller import SystemControllerSingleInstance +from aiet.backend.protocol import ProtocolFactory +from aiet.backend.protocol import SupportsClose +from aiet.backend.protocol import SupportsConnection +from aiet.backend.protocol import SupportsDeploy +from aiet.backend.source import create_destination_and_install +from aiet.backend.source import get_source +from aiet.utils.fs import get_resources + + +def get_available_systems_directory_names() -> List[str]: + """Return a list of directory names for all avialable systems.""" + return [entry.name for entry in get_backend_directories("systems")] + + +def get_available_systems() -> List["System"]: + """Return a list with all available systems.""" + available_systems = [] + for config_json in get_backend_configs("systems"): + config_entries = cast(List[SystemConfig], (load_config(config_json))) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + system = load_system(config_entry) + available_systems.append(system) + + return sorted(available_systems, key=lambda system: system.name) + + +def get_system(system_name: str) -> Optional["System"]: + """Return a system instance with the same name passed as argument.""" + available_systems = get_available_systems() + for system in available_systems: + if system_name == system.name: + return system + return None + + +def install_system(source_path: Path) -> None: + """Install new system.""" + try: + source = get_source(source_path) + config = cast(List[SystemConfig], source.config()) + systems_to_install = [load_system(entry) for entry in config] + except Exception as error: + raise ConfigurationException("Unable to read system definition") from error + + if not systems_to_install: + raise ConfigurationException("No system definition found") + + available_systems = get_available_systems() + already_installed = [s for s in systems_to_install if s in available_systems] + if already_installed: + names = [system.name for system in already_installed] + raise ConfigurationException( + "Systems [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_resources("systems")) + + +def remove_system(directory_name: str) -> None: + """Remove system.""" + remove_backend(directory_name, "systems") + + +class System(Backend): + """System class.""" + + def __init__(self, config: SystemConfig) -> None: + """Construct the System class using the dictionary passed.""" + super().__init__(config) + + self._setup_data_transfer(config) + self._setup_reporting(config) + + def _setup_data_transfer(self, config: SystemConfig) -> None: + data_transfer_config = config.get("data_transfer") + protocol = ProtocolFactory().get_protocol( + data_transfer_config, cwd=self.config_location + ) + self.protocol = protocol + + def _setup_reporting(self, config: SystemConfig) -> None: + self.reporting = config.get("reporting") + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command on the system. + + Returns a tuple: (exit_code, stdout, stderr) + """ + return self.protocol.run(command, retry) + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy files to the system.""" + if isinstance(self.protocol, SupportsDeploy): + self.protocol.deploy(src, dst, retry) + + @property + def supports_deploy(self) -> bool: + """Check if protocol supports deploy operation.""" + return isinstance(self.protocol, SupportsDeploy) + + @property + def connectable(self) -> bool: + """Check if protocol supports connection.""" + return isinstance(self.protocol, SupportsConnection) + + def establish_connection(self) -> bool: + """Establish connection with the system.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.establish_connection() + + def connection_details(self) -> Tuple[str, int]: + """Return connection details.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.connection_details() + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, System): + return False + + return super().__eq__(other) and self.name == other.name + + def get_details(self) -> Dict[str, Any]: + """Return a dictionary with all relevant information of a System.""" + output = { + "type": "system", + "name": self.name, + "description": self.description, + "data_transfer_protocol": self.protocol.protocol, + "commands": self._get_command_details(), + "annotations": self.annotations, + } + + return output + + +class StandaloneSystem(System): + """StandaloneSystem class.""" + + +def get_controller( + single_instance: bool, pid_file_path: Optional[Path] = None +) -> SystemController: + """Get system controller.""" + if single_instance: + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +class ControlledSystem(System): + """ControlledSystem class.""" + + def __init__(self, config: SystemConfig): + """Construct the ControlledSystem class using the dictionary passed.""" + super().__init__(config) + self.controller: Optional[SystemController] = None + + def start( + self, + commands: List[str], + single_instance: bool = True, + pid_file_path: Optional[Path] = None, + ) -> None: + """Launch the system.""" + if ( + not isinstance(self.config_location, Path) + or not self.config_location.is_dir() + ): + raise ConfigurationException( + "System {} has wrong config location".format(self.name) + ) + + self.controller = get_controller(single_instance, pid_file_path) + self.controller.start(commands, self.config_location) + + def is_running(self) -> bool: + """Check if system is running.""" + if not self.controller: + return False + + return self.controller.is_running() + + def get_output(self) -> Tuple[str, str]: + """Return system output.""" + if not self.controller: + return "", "" + + return self.controller.get_output() + + def stop(self, wait: bool = False) -> None: + """Stop the system.""" + if not self.controller: + raise Exception("System has not been started") + + if isinstance(self.protocol, SupportsClose): + try: + self.protocol.close() + except Exception as error: # pylint: disable=broad-except + print(error) + self.controller.stop(wait) + + +def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: + """Load system based on it's execution type.""" + data_transfer = config.get("data_transfer", {}) + protocol = data_transfer.get("protocol") + populate_shared_params(config) + + if protocol == "ssh": + return ControlledSystem(config) + + if protocol == "local": + return StandaloneSystem(config) + + raise ConfigurationException( + "Unsupported execution type for protocol {}".format(protocol) + ) + + +def populate_shared_params(config: SystemConfig) -> None: + """Populate command parameters with shared parameters.""" + user_params = config.get("user_params") + if not user_params or "shared" not in user_params: + return + + shared_user_params = user_params["shared"] + if not shared_user_params: + return + + only_aliases = all(p.get("alias") for p in shared_user_params) + if not only_aliases: + raise ConfigurationException("All shared parameters should have aliases") + + commands = config.get("commands", {}) + for cmd_name in ["build", "run"]: + command = commands.get(cmd_name) + if command is None: + commands[cmd_name] = [] + cmd_user_params = user_params.get(cmd_name) + if not cmd_user_params: + cmd_user_params = shared_user_params + else: + only_aliases = all(p.get("alias") for p in cmd_user_params) + if not only_aliases: + raise ConfigurationException( + "All parameters for command {} should have aliases".format(cmd_name) + ) + merged_by_alias = { + **{p.get("alias"): p for p in shared_user_params}, + **{p.get("alias"): p for p in cmd_user_params}, + } + cmd_user_params = list(merged_by_alias.values()) + + user_params[cmd_name] = cmd_user_params + + config["commands"] = commands + del user_params["shared"] |