aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/protocol.py
blob: c621436b66de350699a0cc8447d38a37cb1e4501 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain protocol related classes and functions."""
from abc import ABC
from abc import abstractmethod
from contextlib import closing
from pathlib import Path
from typing import Any
from typing import cast
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import Union

import paramiko

from aiet.backend.common import ConfigurationException
from aiet.backend.config import LocalProtocolConfig
from aiet.backend.config import SSHConfig
from aiet.utils.proc import run_and_wait


# Redirect all paramiko thread exceptions to a file otherwise these will be
# printed to stderr.
paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)


class SSHConnectionException(Exception):
    """SSH connection exception."""


class SupportsClose(ABC):
    """Class indicates support of close operation."""

    @abstractmethod
    def close(self) -> None:
        """Close protocol session."""


class SupportsDeploy(ABC):
    """Class indicates support of deploy operation."""

    @abstractmethod
    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Abstract method for deploy data."""


class SupportsConnection(ABC):
    """Class indicates that protocol uses network connections."""

    @abstractmethod
    def establish_connection(self) -> bool:
        """Establish connection with underlying system."""

    @abstractmethod
    def connection_details(self) -> Tuple[str, int]:
        """Return connection details (host, port)."""


class Protocol(ABC):
    """Abstract class for representing the protocol."""

    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
        """Initialize the class using a dict."""
        self.__dict__.update(iterable, **kwargs)
        self._validate()

    @abstractmethod
    def _validate(self) -> None:
        """Abstract method for config data validation."""

    @abstractmethod
    def run(
        self, command: str, retry: bool = False
    ) -> Tuple[int, bytearray, bytearray]:
        """
        Abstract method for running commands.

        Returns a tuple: (exit_code, stdout, stderr)
        """


class CustomSFTPClient(paramiko.SFTPClient):
    """Class for creating a custom sftp client."""

    def put_dir(self, source: Path, target: str) -> None:
        """Upload the source directory to the target path.

        The target directory needs to exists and the last directory of the
        source will be created under the target with all its content.
        """
        # Create the target directory
        self._mkdir(target, ignore_existing=True)
        # Create the last directory in the source on the target
        self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
        # Go through the whole content of source
        for item in sorted(source.glob("**/*")):
            relative_path = item.relative_to(source.parent)
            remote_target = target / relative_path
            if item.is_file():
                self.put(str(item), str(remote_target))
            else:
                self._mkdir(str(remote_target), ignore_existing=True)

    def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
        """Extend mkdir functionality.

        This version adds an option to not fail if the folder exists.
        """
        try:
            super().mkdir(path, mode)
        except IOError as error:
            if ignore_existing:
                pass
            else:
                raise error


class LocalProtocol(Protocol):
    """Class for local protocol."""

    protocol: str
    cwd: Path

    def run(
        self, command: str, retry: bool = False
    ) -> Tuple[int, bytearray, bytearray]:
        """
        Run command locally.

        Returns a tuple: (exit_code, stdout, stderr)
        """
        if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
            raise ConfigurationException("Wrong working directory {}".format(self.cwd))

        stdout = bytearray()
        stderr = bytearray()

        return run_and_wait(
            command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
        )

    def _validate(self) -> None:
        """Validate protocol configuration."""
        assert hasattr(self, "protocol") and self.protocol == "local"
        assert hasattr(self, "cwd")


class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
    """Class for SSH protocol."""

    protocol: str
    username: str
    password: str
    hostname: str
    port: int

    def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
        """Initialize the class using a dict."""
        super().__init__(iterable, **kwargs)
        # Internal state to store if the system is connectable. It will be set
        # to true at the first connection instance
        self.client: Optional[paramiko.client.SSHClient] = None
        self.port = int(self.port)

    def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
        """
        Run command over SSH.

        Returns a tuple: (exit_code, stdout, stderr)
        """
        transport = self._get_transport()
        with closing(transport.open_session()) as channel:
            # Enable shell's .profile settings and execute command
            channel.exec_command("bash -l -c '{}'".format(command))
            exit_status = -1
            stdout = bytearray()
            stderr = bytearray()
            while True:
                if channel.exit_status_ready():
                    exit_status = channel.recv_exit_status()
                    # Call it one last time to read any leftover in the channel
                    self._recv_stdout_err(channel, stdout, stderr)
                    break
                self._recv_stdout_err(channel, stdout, stderr)

        return exit_status, stdout, stderr

    def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
        """Deploy src to remote dst over SSH.

        src and dst should be path to a file or directory.
        """
        transport = self._get_transport()
        sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))

        with closing(sftp):
            if src.is_dir():
                sftp.put_dir(src, dst)
            elif src.is_file():
                sftp.put(str(src), dst)
            else:
                raise Exception("Deploy error: file type not supported")

        # After the deployment of files, sync the remote filesystem to flush
        # buffers to hard disk
        self.run("sync")

    def close(self) -> None:
        """Close protocol session."""
        if self.client is not None:
            print("Try syncing remote file system...")
            # Before stopping the system, we try to run sync to make sure all
            # data are flushed on disk.
            self.run("sync", retry=False)
            self._close_client(self.client)

    def establish_connection(self) -> bool:
        """Establish connection with underlying system."""
        if self.client is not None:
            return True

        self.client = self._connect()
        return self.client is not None

    def _get_transport(self) -> paramiko.transport.Transport:
        """Get transport."""
        self.establish_connection()

        if self.client is None:
            raise SSHConnectionException(
                "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
            )

        transport = self.client.get_transport()
        if not transport:
            raise Exception("Unable to get transport")

        return transport

    def connection_details(self) -> Tuple[str, int]:
        """Return connection details of underlying system."""
        return (self.hostname, self.port)

    def _connect(self) -> Optional[paramiko.client.SSHClient]:
        """Try to establish connection."""
        client: Optional[paramiko.client.SSHClient] = None
        try:
            client = paramiko.client.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(
                self.hostname,
                self.port,
                self.username,
                self.password,
                # next parameters should be set to False to disable authentication
                # using ssh keys
                allow_agent=False,
                look_for_keys=False,
            )
            return client
        except (
            # OSError raised on first attempt to connect when running inside Docker
            OSError,
            paramiko.ssh_exception.NoValidConnectionsError,
            paramiko.ssh_exception.SSHException,
        ):
            # even if connection is not established socket could be still
            # open, it should be closed
            self._close_client(client)

            return None

    @staticmethod
    def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
        """Close ssh client."""
        try:
            if client is not None:
                client.close()
        except Exception:  # pylint: disable=broad-except
            pass

    @classmethod
    def _recv_stdout_err(
        cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
    ) -> None:
        """Read from channel to stdout/stder."""
        chunk_size = 512
        if channel.recv_ready():
            stdout_chunk = channel.recv(chunk_size)
            stdout.extend(stdout_chunk)
        if channel.recv_stderr_ready():
            stderr_chunk = channel.recv_stderr(chunk_size)
            stderr.extend(stderr_chunk)

    def _validate(self) -> None:
        """Check if there are all the info for establishing the connection."""
        assert hasattr(self, "protocol") and self.protocol == "ssh"
        assert hasattr(self, "username")
        assert hasattr(self, "password")
        assert hasattr(self, "hostname")
        assert hasattr(self, "port")


class ProtocolFactory:
    """Factory class to return the appropriate Protocol class."""

    @staticmethod
    def get_protocol(
        config: Optional[Union[SSHConfig, LocalProtocolConfig]],
        **kwargs: Union[str, Path, None]
    ) -> Union[SSHProtocol, LocalProtocol]:
        """Return the right protocol instance based on the config."""
        if not config:
            raise ValueError("No protocol config provided")

        protocol = config["protocol"]
        if protocol == "ssh":
            return SSHProtocol(config)

        if protocol == "local":
            cwd = kwargs.get("cwd")
            return LocalProtocol(config, cwd=cwd)

        raise ValueError("Protocol not supported: '{}'".format(protocol))