aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/protocol.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-06-28 10:29:35 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-08 10:57:19 +0100
commitc9b4089b3037b5943565d76242d3016b8776f8d2 (patch)
tree3de24f79dedf0f26f492a7fa1562bf684e13a055 /src/aiet/backend/protocol.py
parentba2c7fcccf37e8c81946f0776714c64f73191787 (diff)
downloadmlia-c9b4089b3037b5943565d76242d3016b8776f8d2.tar.gz
MLIA-546 Merge AIET into MLIA
Merge the deprecated AIET interface for backend execution into MLIA: - Execute backends directly (without subprocess and the aiet CLI) - Fix issues with the unit tests - Remove src/aiet and tests/aiet - Re-factor code to replace 'aiet' with 'backend' - Adapt and improve unit tests after re-factoring - Remove dependencies that are not needed anymore (click and cloup) Change-Id: I450734c6a3f705ba9afde41862b29e797e511f7c
Diffstat (limited to 'src/aiet/backend/protocol.py')
-rw-r--r--src/aiet/backend/protocol.py325
1 files changed, 0 insertions, 325 deletions
diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py
deleted file mode 100644
index c621436..0000000
--- a/src/aiet/backend/protocol.py
+++ /dev/null
@@ -1,325 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Contain protocol related classes and functions."""
-from abc import ABC
-from abc import abstractmethod
-from contextlib import closing
-from pathlib import Path
-from typing import Any
-from typing import cast
-from typing import Iterable
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import paramiko
-
-from aiet.backend.common import ConfigurationException
-from aiet.backend.config import LocalProtocolConfig
-from aiet.backend.config import SSHConfig
-from aiet.utils.proc import run_and_wait
-
-
-# Redirect all paramiko thread exceptions to a file otherwise these will be
-# printed to stderr.
-paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)
-
-
-class SSHConnectionException(Exception):
- """SSH connection exception."""
-
-
-class SupportsClose(ABC):
- """Class indicates support of close operation."""
-
- @abstractmethod
- def close(self) -> None:
- """Close protocol session."""
-
-
-class SupportsDeploy(ABC):
- """Class indicates support of deploy operation."""
-
- @abstractmethod
- def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
- """Abstract method for deploy data."""
-
-
-class SupportsConnection(ABC):
- """Class indicates that protocol uses network connections."""
-
- @abstractmethod
- def establish_connection(self) -> bool:
- """Establish connection with underlying system."""
-
- @abstractmethod
- def connection_details(self) -> Tuple[str, int]:
- """Return connection details (host, port)."""
-
-
-class Protocol(ABC):
- """Abstract class for representing the protocol."""
-
- def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
- """Initialize the class using a dict."""
- self.__dict__.update(iterable, **kwargs)
- self._validate()
-
- @abstractmethod
- def _validate(self) -> None:
- """Abstract method for config data validation."""
-
- @abstractmethod
- def run(
- self, command: str, retry: bool = False
- ) -> Tuple[int, bytearray, bytearray]:
- """
- Abstract method for running commands.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
-
-
-class CustomSFTPClient(paramiko.SFTPClient):
- """Class for creating a custom sftp client."""
-
- def put_dir(self, source: Path, target: str) -> None:
- """Upload the source directory to the target path.
-
- The target directory needs to exists and the last directory of the
- source will be created under the target with all its content.
- """
- # Create the target directory
- self._mkdir(target, ignore_existing=True)
- # Create the last directory in the source on the target
- self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
- # Go through the whole content of source
- for item in sorted(source.glob("**/*")):
- relative_path = item.relative_to(source.parent)
- remote_target = target / relative_path
- if item.is_file():
- self.put(str(item), str(remote_target))
- else:
- self._mkdir(str(remote_target), ignore_existing=True)
-
- def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
- """Extend mkdir functionality.
-
- This version adds an option to not fail if the folder exists.
- """
- try:
- super().mkdir(path, mode)
- except IOError as error:
- if ignore_existing:
- pass
- else:
- raise error
-
-
-class LocalProtocol(Protocol):
- """Class for local protocol."""
-
- protocol: str
- cwd: Path
-
- def run(
- self, command: str, retry: bool = False
- ) -> Tuple[int, bytearray, bytearray]:
- """
- Run command locally.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
- raise ConfigurationException("Wrong working directory {}".format(self.cwd))
-
- stdout = bytearray()
- stderr = bytearray()
-
- return run_and_wait(
- command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
- )
-
- def _validate(self) -> None:
- """Validate protocol configuration."""
- assert hasattr(self, "protocol") and self.protocol == "local"
- assert hasattr(self, "cwd")
-
-
-class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
- """Class for SSH protocol."""
-
- protocol: str
- username: str
- password: str
- hostname: str
- port: int
-
- def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
- """Initialize the class using a dict."""
- super().__init__(iterable, **kwargs)
- # Internal state to store if the system is connectable. It will be set
- # to true at the first connection instance
- self.client: Optional[paramiko.client.SSHClient] = None
- self.port = int(self.port)
-
- def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
- """
- Run command over SSH.
-
- Returns a tuple: (exit_code, stdout, stderr)
- """
- transport = self._get_transport()
- with closing(transport.open_session()) as channel:
- # Enable shell's .profile settings and execute command
- channel.exec_command("bash -l -c '{}'".format(command))
- exit_status = -1
- stdout = bytearray()
- stderr = bytearray()
- while True:
- if channel.exit_status_ready():
- exit_status = channel.recv_exit_status()
- # Call it one last time to read any leftover in the channel
- self._recv_stdout_err(channel, stdout, stderr)
- break
- self._recv_stdout_err(channel, stdout, stderr)
-
- return exit_status, stdout, stderr
-
- def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
- """Deploy src to remote dst over SSH.
-
- src and dst should be path to a file or directory.
- """
- transport = self._get_transport()
- sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))
-
- with closing(sftp):
- if src.is_dir():
- sftp.put_dir(src, dst)
- elif src.is_file():
- sftp.put(str(src), dst)
- else:
- raise Exception("Deploy error: file type not supported")
-
- # After the deployment of files, sync the remote filesystem to flush
- # buffers to hard disk
- self.run("sync")
-
- def close(self) -> None:
- """Close protocol session."""
- if self.client is not None:
- print("Try syncing remote file system...")
- # Before stopping the system, we try to run sync to make sure all
- # data are flushed on disk.
- self.run("sync", retry=False)
- self._close_client(self.client)
-
- def establish_connection(self) -> bool:
- """Establish connection with underlying system."""
- if self.client is not None:
- return True
-
- self.client = self._connect()
- return self.client is not None
-
- def _get_transport(self) -> paramiko.transport.Transport:
- """Get transport."""
- self.establish_connection()
-
- if self.client is None:
- raise SSHConnectionException(
- "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
- )
-
- transport = self.client.get_transport()
- if not transport:
- raise Exception("Unable to get transport")
-
- return transport
-
- def connection_details(self) -> Tuple[str, int]:
- """Return connection details of underlying system."""
- return (self.hostname, self.port)
-
- def _connect(self) -> Optional[paramiko.client.SSHClient]:
- """Try to establish connection."""
- client: Optional[paramiko.client.SSHClient] = None
- try:
- client = paramiko.client.SSHClient()
- client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- client.connect(
- self.hostname,
- self.port,
- self.username,
- self.password,
- # next parameters should be set to False to disable authentication
- # using ssh keys
- allow_agent=False,
- look_for_keys=False,
- )
- return client
- except (
- # OSError raised on first attempt to connect when running inside Docker
- OSError,
- paramiko.ssh_exception.NoValidConnectionsError,
- paramiko.ssh_exception.SSHException,
- ):
- # even if connection is not established socket could be still
- # open, it should be closed
- self._close_client(client)
-
- return None
-
- @staticmethod
- def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
- """Close ssh client."""
- try:
- if client is not None:
- client.close()
- except Exception: # pylint: disable=broad-except
- pass
-
- @classmethod
- def _recv_stdout_err(
- cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
- ) -> None:
- """Read from channel to stdout/stder."""
- chunk_size = 512
- if channel.recv_ready():
- stdout_chunk = channel.recv(chunk_size)
- stdout.extend(stdout_chunk)
- if channel.recv_stderr_ready():
- stderr_chunk = channel.recv_stderr(chunk_size)
- stderr.extend(stderr_chunk)
-
- def _validate(self) -> None:
- """Check if there are all the info for establishing the connection."""
- assert hasattr(self, "protocol") and self.protocol == "ssh"
- assert hasattr(self, "username")
- assert hasattr(self, "password")
- assert hasattr(self, "hostname")
- assert hasattr(self, "port")
-
-
-class ProtocolFactory:
- """Factory class to return the appropriate Protocol class."""
-
- @staticmethod
- def get_protocol(
- config: Optional[Union[SSHConfig, LocalProtocolConfig]],
- **kwargs: Union[str, Path, None]
- ) -> Union[SSHProtocol, LocalProtocol]:
- """Return the right protocol instance based on the config."""
- if not config:
- raise ValueError("No protocol config provided")
-
- protocol = config["protocol"]
- if protocol == "ssh":
- return SSHProtocol(config)
-
- if protocol == "local":
- cwd = kwargs.get("cwd")
- return LocalProtocol(config, cwd=cwd)
-
- raise ValueError("Protocol not supported: '{}'".format(protocol))