From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/aiet/backend/protocol.py | 325 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 src/aiet/backend/protocol.py (limited to 'src/aiet/backend/protocol.py') diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py new file mode 100644 index 0000000..c621436 --- /dev/null +++ b/src/aiet/backend/protocol.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain protocol related classes and functions.""" +from abc import ABC +from abc import abstractmethod +from contextlib import closing +from pathlib import Path +from typing import Any +from typing import cast +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Union + +import paramiko + +from aiet.backend.common import ConfigurationException +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import SSHConfig +from aiet.utils.proc import run_and_wait + + +# Redirect all paramiko thread exceptions to a file otherwise these will be +# printed to stderr. +paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) + + +class SSHConnectionException(Exception): + """SSH connection exception.""" + + +class SupportsClose(ABC): + """Class indicates support of close operation.""" + + @abstractmethod + def close(self) -> None: + """Close protocol session.""" + + +class SupportsDeploy(ABC): + """Class indicates support of deploy operation.""" + + @abstractmethod + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Abstract method for deploy data.""" + + +class SupportsConnection(ABC): + """Class indicates that protocol uses network connections.""" + + @abstractmethod + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + + @abstractmethod + def connection_details(self) -> Tuple[str, int]: + """Return connection details (host, port).""" + + +class Protocol(ABC): + """Abstract class for representing the protocol.""" + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + self.__dict__.update(iterable, **kwargs) + self._validate() + + @abstractmethod + def _validate(self) -> None: + """Abstract method for config data validation.""" + + @abstractmethod + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Abstract method for running commands. + + Returns a tuple: (exit_code, stdout, stderr) + """ + + +class CustomSFTPClient(paramiko.SFTPClient): + """Class for creating a custom sftp client.""" + + def put_dir(self, source: Path, target: str) -> None: + """Upload the source directory to the target path. + + The target directory needs to exists and the last directory of the + source will be created under the target with all its content. + """ + # Create the target directory + self._mkdir(target, ignore_existing=True) + # Create the last directory in the source on the target + self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) + # Go through the whole content of source + for item in sorted(source.glob("**/*")): + relative_path = item.relative_to(source.parent) + remote_target = target / relative_path + if item.is_file(): + self.put(str(item), str(remote_target)) + else: + self._mkdir(str(remote_target), ignore_existing=True) + + def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: + """Extend mkdir functionality. + + This version adds an option to not fail if the folder exists. + """ + try: + super().mkdir(path, mode) + except IOError as error: + if ignore_existing: + pass + else: + raise error + + +class LocalProtocol(Protocol): + """Class for local protocol.""" + + protocol: str + cwd: Path + + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Run command locally. + + Returns a tuple: (exit_code, stdout, stderr) + """ + if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): + raise ConfigurationException("Wrong working directory {}".format(self.cwd)) + + stdout = bytearray() + stderr = bytearray() + + return run_and_wait( + command, self.cwd, terminate_on_error=True, out=stdout, err=stderr + ) + + def _validate(self) -> None: + """Validate protocol configuration.""" + assert hasattr(self, "protocol") and self.protocol == "local" + assert hasattr(self, "cwd") + + +class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): + """Class for SSH protocol.""" + + protocol: str + username: str + password: str + hostname: str + port: int + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + super().__init__(iterable, **kwargs) + # Internal state to store if the system is connectable. It will be set + # to true at the first connection instance + self.client: Optional[paramiko.client.SSHClient] = None + self.port = int(self.port) + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command over SSH. + + Returns a tuple: (exit_code, stdout, stderr) + """ + transport = self._get_transport() + with closing(transport.open_session()) as channel: + # Enable shell's .profile settings and execute command + channel.exec_command("bash -l -c '{}'".format(command)) + exit_status = -1 + stdout = bytearray() + stderr = bytearray() + while True: + if channel.exit_status_ready(): + exit_status = channel.recv_exit_status() + # Call it one last time to read any leftover in the channel + self._recv_stdout_err(channel, stdout, stderr) + break + self._recv_stdout_err(channel, stdout, stderr) + + return exit_status, stdout, stderr + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy src to remote dst over SSH. + + src and dst should be path to a file or directory. + """ + transport = self._get_transport() + sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) + + with closing(sftp): + if src.is_dir(): + sftp.put_dir(src, dst) + elif src.is_file(): + sftp.put(str(src), dst) + else: + raise Exception("Deploy error: file type not supported") + + # After the deployment of files, sync the remote filesystem to flush + # buffers to hard disk + self.run("sync") + + def close(self) -> None: + """Close protocol session.""" + if self.client is not None: + print("Try syncing remote file system...") + # Before stopping the system, we try to run sync to make sure all + # data are flushed on disk. + self.run("sync", retry=False) + self._close_client(self.client) + + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + if self.client is not None: + return True + + self.client = self._connect() + return self.client is not None + + def _get_transport(self) -> paramiko.transport.Transport: + """Get transport.""" + self.establish_connection() + + if self.client is None: + raise SSHConnectionException( + "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) + ) + + transport = self.client.get_transport() + if not transport: + raise Exception("Unable to get transport") + + return transport + + def connection_details(self) -> Tuple[str, int]: + """Return connection details of underlying system.""" + return (self.hostname, self.port) + + def _connect(self) -> Optional[paramiko.client.SSHClient]: + """Try to establish connection.""" + client: Optional[paramiko.client.SSHClient] = None + try: + client = paramiko.client.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + self.hostname, + self.port, + self.username, + self.password, + # next parameters should be set to False to disable authentication + # using ssh keys + allow_agent=False, + look_for_keys=False, + ) + return client + except ( + # OSError raised on first attempt to connect when running inside Docker + OSError, + paramiko.ssh_exception.NoValidConnectionsError, + paramiko.ssh_exception.SSHException, + ): + # even if connection is not established socket could be still + # open, it should be closed + self._close_client(client) + + return None + + @staticmethod + def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: + """Close ssh client.""" + try: + if client is not None: + client.close() + except Exception: # pylint: disable=broad-except + pass + + @classmethod + def _recv_stdout_err( + cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray + ) -> None: + """Read from channel to stdout/stder.""" + chunk_size = 512 + if channel.recv_ready(): + stdout_chunk = channel.recv(chunk_size) + stdout.extend(stdout_chunk) + if channel.recv_stderr_ready(): + stderr_chunk = channel.recv_stderr(chunk_size) + stderr.extend(stderr_chunk) + + def _validate(self) -> None: + """Check if there are all the info for establishing the connection.""" + assert hasattr(self, "protocol") and self.protocol == "ssh" + assert hasattr(self, "username") + assert hasattr(self, "password") + assert hasattr(self, "hostname") + assert hasattr(self, "port") + + +class ProtocolFactory: + """Factory class to return the appropriate Protocol class.""" + + @staticmethod + def get_protocol( + config: Optional[Union[SSHConfig, LocalProtocolConfig]], + **kwargs: Union[str, Path, None] + ) -> Union[SSHProtocol, LocalProtocol]: + """Return the right protocol instance based on the config.""" + if not config: + raise ValueError("No protocol config provided") + + protocol = config["protocol"] + if protocol == "ssh": + return SSHProtocol(config) + + if protocol == "local": + cwd = kwargs.get("cwd") + return LocalProtocol(config, cwd=cwd) + + raise ValueError("Protocol not supported: '{}'".format(protocol)) -- cgit v1.2.1