aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/system.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/aiet/backend/system.py')
-rw-r--r--src/aiet/backend/system.py289
1 files changed, 289 insertions, 0 deletions
diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py
new file mode 100644
index 0000000..48f1bb1
--- /dev/null
+++ b/src/aiet/backend/system.py
@@ -0,0 +1,289 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""System backend module."""
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_config
+from aiet.backend.common import remove_backend
+from aiet.backend.config import SystemConfig
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SupportsClose
+from aiet.backend.protocol import SupportsConnection
+from aiet.backend.protocol import SupportsDeploy
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import get_source
+from aiet.utils.fs import get_resources
+
+
+def get_available_systems_directory_names() -> List[str]:
+ """Return a list of directory names for all avialable systems."""
+ return [entry.name for entry in get_backend_directories("systems")]
+
+
+def get_available_systems() -> List["System"]:
+ """Return a list with all available systems."""
+ available_systems = []
+ for config_json in get_backend_configs("systems"):
+ config_entries = cast(List[SystemConfig], (load_config(config_json)))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ system = load_system(config_entry)
+ available_systems.append(system)
+
+ return sorted(available_systems, key=lambda system: system.name)
+
+
+def get_system(system_name: str) -> Optional["System"]:
+ """Return a system instance with the same name passed as argument."""
+ available_systems = get_available_systems()
+ for system in available_systems:
+ if system_name == system.name:
+ return system
+ return None
+
+
+def install_system(source_path: Path) -> None:
+ """Install new system."""
+ try:
+ source = get_source(source_path)
+ config = cast(List[SystemConfig], source.config())
+ systems_to_install = [load_system(entry) for entry in config]
+ except Exception as error:
+ raise ConfigurationException("Unable to read system definition") from error
+
+ if not systems_to_install:
+ raise ConfigurationException("No system definition found")
+
+ available_systems = get_available_systems()
+ already_installed = [s for s in systems_to_install if s in available_systems]
+ if already_installed:
+ names = [system.name for system in already_installed]
+ raise ConfigurationException(
+ "Systems [{}] are already installed".format(",".join(names))
+ )
+
+ create_destination_and_install(source, get_resources("systems"))
+
+
+def remove_system(directory_name: str) -> None:
+ """Remove system."""
+ remove_backend(directory_name, "systems")
+
+
+class System(Backend):
+ """System class."""
+
+ def __init__(self, config: SystemConfig) -> None:
+ """Construct the System class using the dictionary passed."""
+ super().__init__(config)
+
+ self._setup_data_transfer(config)
+ self._setup_reporting(config)
+
+ def _setup_data_transfer(self, config: SystemConfig) -> None:
+ data_transfer_config = config.get("data_transfer")
+ protocol = ProtocolFactory().get_protocol(
+ data_transfer_config, cwd=self.config_location
+ )
+ self.protocol = protocol
+
+ def _setup_reporting(self, config: SystemConfig) -> None:
+ self.reporting = config.get("reporting")
+
+ def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command on the system.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ return self.protocol.run(command, retry)
+
+ def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+ """Deploy files to the system."""
+ if isinstance(self.protocol, SupportsDeploy):
+ self.protocol.deploy(src, dst, retry)
+
+ @property
+ def supports_deploy(self) -> bool:
+ """Check if protocol supports deploy operation."""
+ return isinstance(self.protocol, SupportsDeploy)
+
+ @property
+ def connectable(self) -> bool:
+ """Check if protocol supports connection."""
+ return isinstance(self.protocol, SupportsConnection)
+
+ def establish_connection(self) -> bool:
+ """Establish connection with the system."""
+ if not isinstance(self.protocol, SupportsConnection):
+ raise ConfigurationException(
+ "System {} does not support connections".format(self.name)
+ )
+
+ return self.protocol.establish_connection()
+
+ def connection_details(self) -> Tuple[str, int]:
+ """Return connection details."""
+ if not isinstance(self.protocol, SupportsConnection):
+ raise ConfigurationException(
+ "System {} does not support connections".format(self.name)
+ )
+
+ return self.protocol.connection_details()
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, System):
+ return False
+
+ return super().__eq__(other) and self.name == other.name
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return a dictionary with all relevant information of a System."""
+ output = {
+ "type": "system",
+ "name": self.name,
+ "description": self.description,
+ "data_transfer_protocol": self.protocol.protocol,
+ "commands": self._get_command_details(),
+ "annotations": self.annotations,
+ }
+
+ return output
+
+
+class StandaloneSystem(System):
+ """StandaloneSystem class."""
+
+
+def get_controller(
+ single_instance: bool, pid_file_path: Optional[Path] = None
+) -> SystemController:
+ """Get system controller."""
+ if single_instance:
+ return SystemControllerSingleInstance(pid_file_path)
+
+ return SystemController()
+
+
+class ControlledSystem(System):
+ """ControlledSystem class."""
+
+ def __init__(self, config: SystemConfig):
+ """Construct the ControlledSystem class using the dictionary passed."""
+ super().__init__(config)
+ self.controller: Optional[SystemController] = None
+
+ def start(
+ self,
+ commands: List[str],
+ single_instance: bool = True,
+ pid_file_path: Optional[Path] = None,
+ ) -> None:
+ """Launch the system."""
+ if (
+ not isinstance(self.config_location, Path)
+ or not self.config_location.is_dir()
+ ):
+ raise ConfigurationException(
+ "System {} has wrong config location".format(self.name)
+ )
+
+ self.controller = get_controller(single_instance, pid_file_path)
+ self.controller.start(commands, self.config_location)
+
+ def is_running(self) -> bool:
+ """Check if system is running."""
+ if not self.controller:
+ return False
+
+ return self.controller.is_running()
+
+ def get_output(self) -> Tuple[str, str]:
+ """Return system output."""
+ if not self.controller:
+ return "", ""
+
+ return self.controller.get_output()
+
+ def stop(self, wait: bool = False) -> None:
+ """Stop the system."""
+ if not self.controller:
+ raise Exception("System has not been started")
+
+ if isinstance(self.protocol, SupportsClose):
+ try:
+ self.protocol.close()
+ except Exception as error: # pylint: disable=broad-except
+ print(error)
+ self.controller.stop(wait)
+
+
+def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]:
+ """Load system based on it's execution type."""
+ data_transfer = config.get("data_transfer", {})
+ protocol = data_transfer.get("protocol")
+ populate_shared_params(config)
+
+ if protocol == "ssh":
+ return ControlledSystem(config)
+
+ if protocol == "local":
+ return StandaloneSystem(config)
+
+ raise ConfigurationException(
+ "Unsupported execution type for protocol {}".format(protocol)
+ )
+
+
+def populate_shared_params(config: SystemConfig) -> None:
+ """Populate command parameters with shared parameters."""
+ user_params = config.get("user_params")
+ if not user_params or "shared" not in user_params:
+ return
+
+ shared_user_params = user_params["shared"]
+ if not shared_user_params:
+ return
+
+ only_aliases = all(p.get("alias") for p in shared_user_params)
+ if not only_aliases:
+ raise ConfigurationException("All shared parameters should have aliases")
+
+ commands = config.get("commands", {})
+ for cmd_name in ["build", "run"]:
+ command = commands.get(cmd_name)
+ if command is None:
+ commands[cmd_name] = []
+ cmd_user_params = user_params.get(cmd_name)
+ if not cmd_user_params:
+ cmd_user_params = shared_user_params
+ else:
+ only_aliases = all(p.get("alias") for p in cmd_user_params)
+ if not only_aliases:
+ raise ConfigurationException(
+ "All parameters for command {} should have aliases".format(cmd_name)
+ )
+ merged_by_alias = {
+ **{p.get("alias"): p for p in shared_user_params},
+ **{p.get("alias"): p for p in cmd_user_params},
+ }
+ cmd_user_params = list(merged_by_alias.values())
+
+ user_params[cmd_name] = cmd_user_params
+
+ config["commands"] = commands
+ del user_params["shared"]