aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/test_backend_protocol.py
blob: 2103238ddb8dbd4ba88c2ee4bfe53526cae59fba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access
"""Tests for the protocol backend module."""
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import paramiko
import pytest

from aiet.backend.common import ConfigurationException
from aiet.backend.config import LocalProtocolConfig
from aiet.backend.protocol import CustomSFTPClient
from aiet.backend.protocol import LocalProtocol
from aiet.backend.protocol import ProtocolFactory
from aiet.backend.protocol import SSHProtocol


class TestProtocolFactory:
    """Tests for ProtocolFactory.get_protocol dispatch."""

    @pytest.mark.parametrize(
        "config, expected_class, exception",
        [
            # A complete SSH configuration produces an SSHProtocol.
            (
                {
                    "protocol": "ssh",
                    "username": "user",
                    "password": "pass",
                    "hostname": "hostname",
                    "port": "22",
                },
                SSHProtocol,
                does_not_raise(),
            ),
            # The local protocol needs nothing beyond its name.
            ({"protocol": "local"}, LocalProtocol, does_not_raise()),
            # Unknown protocol names are rejected.
            (
                {"protocol": "something"},
                None,
                pytest.raises(Exception, match="Protocol not supported"),
            ),
            # A missing configuration is rejected outright.
            (None, None, pytest.raises(Exception, match="No protocol config provided")),
        ],
    )
    def test_get_protocol(
        self, config: Any, expected_class: type, exception: Any
    ) -> None:
        """Check that get_protocol builds the expected class or raises.

        ``exception`` is either a no-op context (success cases) or a
        ``pytest.raises`` context carrying the expected error message.
        """
        factory = ProtocolFactory()
        with exception:
            assert isinstance(factory.get_protocol(config), expected_class)


class TestLocalProtocol:
    """Test local protocol."""

    def test_local_protocol_run_command(self, tmp_path: Path) -> None:
        """Test that a local command runs inside the configured working directory.

        A pytest-managed temporary directory is used instead of a hard-coded
        ``/tmp``, so the test does not depend on that path existing. Both sides
        are resolved before comparing, because ``pwd`` may print the physical
        path when the directory is reached through a symlink (for example
        ``/tmp`` -> ``/private/tmp`` on macOS).
        """
        config = LocalProtocolConfig(protocol="local")
        protocol = LocalProtocol(config, cwd=tmp_path)
        ret, stdout, stderr = protocol.run("pwd")
        assert ret == 0
        assert Path(stdout.decode("utf-8").strip()).resolve() == tmp_path.resolve()
        assert stderr.decode("utf-8") == ""

    def test_local_protocol_run_wrong_cwd(self) -> None:
        """Execution should fail if wrong working directory provided.

        NOTE(review): ``unknown_directory`` is a relative path, so this
        presumably relies on no such directory existing in the directory
        pytest is run from.
        """
        config = LocalProtocolConfig(protocol="local")
        protocol = LocalProtocol(config, cwd=Path("unknown_directory"))
        with pytest.raises(
            ConfigurationException, match="Wrong working directory unknown_directory"
        ):
            protocol.run("pwd")


class TestSSHProtocol:
    """Tests for SSHProtocol with paramiko and SFTP replaced by mocks."""

    @pytest.fixture(autouse=True)
    def setup_method(self, monkeypatch: Any) -> None:
        """Install paramiko/SFTP mocks and build the protocol under test.

        Every test in this class gets:
        * ``self.mock_ssh_client`` - the mocked ``SSHClient`` instance,
        * ``self.mock_ssh_channel`` - the session returned by
          ``get_transport().open_session()``,
        * ``self.mock_sftp_client`` - the mocked ``CustomSFTPClient``,
        * ``self.protocol`` - an ``SSHProtocol`` built from a fixed config.
        """
        client = MagicMock(spec=paramiko.client.SSHClient)
        channel = client.get_transport.return_value.open_session.return_value
        # The spec must be attached before the channel attributes are configured.
        channel.mock_add_spec(spec=paramiko.channel.Channel)
        # Each readiness probe answers "not ready" once and then "ready".
        for probe in ("exit_status_ready", "recv_ready", "recv_stderr_ready"):
            getattr(channel, probe).side_effect = [False, True]
        channel.recv_exit_status.return_value = True

        monkeypatch.setattr(
            "aiet.backend.protocol.paramiko.client.SSHClient",
            MagicMock(return_value=client),
        )

        sftp_client = MagicMock(spec=CustomSFTPClient)
        monkeypatch.setattr(
            "aiet.backend.protocol.CustomSFTPClient.from_transport",
            MagicMock(return_value=sftp_client),
        )

        self.mock_ssh_client = client
        self.mock_ssh_channel = channel
        self.mock_sftp_client = sftp_client
        self.protocol = SSHProtocol(
            {
                "protocol": "ssh",
                "username": "user",
                "password": "pass",
                "hostname": "hostname",
                "port": "22",
            }
        )

    def test_unable_create_ssh_client(self, monkeypatch: Any) -> None:
        """A failure to instantiate SSHClient surfaces as a connection error."""
        broken_factory = MagicMock(side_effect=OSError("Error!"))
        monkeypatch.setattr(
            "aiet.backend.protocol.paramiko.client.SSHClient", broken_factory
        )

        with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
            self.protocol.run("command_example", retry=False)

    def test_ssh_protocol_run_command(self) -> None:
        """Running a command issues exactly one exec_command on the channel."""
        self.protocol.run("command_example")

        self.mock_ssh_channel.exec_command.assert_called_once()

    def test_ssh_protocol_run_command_connect_failed(self) -> None:
        """A failed connect is reported, with close() also configured to fail."""
        client = self.mock_ssh_client
        client.connect.side_effect = OSError("Unable to connect")
        client.close.side_effect = Exception("Error!")

        with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
            self.protocol.run("command_example", retry=False)

    def test_ssh_protocol_run_command_bad_transport(self) -> None:
        """run() fails when the SSH client provides no transport."""
        self.mock_ssh_client.get_transport.return_value = None

        with pytest.raises(Exception, match="Unable to get transport"):
            self.protocol.run("command_example", retry=False)

    def test_ssh_protocol_deploy_command_file(
        self, test_applications_path: Path
    ) -> None:
        """A single file is uploaded through the SFTP client's put()."""
        source = test_applications_path / "readme.txt"
        target = "/tmp/dest"

        self.protocol.deploy(source, target)

        self.mock_sftp_client.put.assert_called_once_with(str(source), target)

    def test_ssh_protocol_deploy_command_unknown_file(self) -> None:
        """Deploying a path that does not exist is rejected."""
        missing = Path("unknown_file")
        with pytest.raises(Exception, match="Deploy error: file type not supported"):
            self.protocol.deploy(missing, "/tmp/dest")

    def test_ssh_protocol_deploy_command_bad_transport(self) -> None:
        """deploy() fails when the SSH client provides no transport."""
        self.mock_ssh_client.get_transport.return_value = None
        source = Path("some_file")

        with pytest.raises(Exception, match="Unable to get transport"):
            self.protocol.deploy(source, "/tmp/dest")

    def test_ssh_protocol_deploy_command_directory(
        self, test_resources_path: Path
    ) -> None:
        """A directory is uploaded through the SFTP client's put_dir()."""
        source_dir = test_resources_path / "scripts"
        target = "/tmp/dest"

        self.protocol.deploy(source_dir, target)

        self.mock_sftp_client.put_dir.assert_called_once_with(source_dir, target)

    @pytest.mark.parametrize("establish_connection", (True, False))
    def test_ssh_protocol_close(self, establish_connection: bool) -> None:
        """close() works whether or not a connection was established first.

        exec_command is expected to have run once only when a connection
        was established before closing.
        """
        if establish_connection:
            self.protocol.establish_connection()
        self.protocol.close()

        expected_calls = int(establish_connection)
        assert self.mock_ssh_channel.exec_command.call_count == expected_calls

    def test_connection_details(self) -> None:
        """connection_details() reports the hostname and the port as an int."""
        expected = ("hostname", 22)
        assert self.protocol.connection_details() == expected


class TestCustomSFTPClient:
    """Tests for the CustomSFTPClient helpers layered on paramiko.SFTPClient."""

    @pytest.fixture(autouse=True)
    def setup_method(self, monkeypatch: Any) -> None:
        """Stub the paramiko SFTPClient base so no real SFTP session is needed.

        ``__init__`` becomes a no-op. ``mkdir`` and ``put`` are replaced with
        mocks exposed as ``self.mock_mkdir`` and ``self.mock_put``.
        """
        self.mock_mkdir = MagicMock()
        self.mock_put = MagicMock()
        replacements = [
            (
                "aiet.backend.protocol.paramiko.SFTPClient.__init__",
                MagicMock(return_value=None),
            ),
            ("aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir),
            ("aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put),
        ]
        for target, replacement in replacements:
            monkeypatch.setattr(target, replacement)

        self.sftp_client = CustomSFTPClient(MagicMock())

    def test_put_dir(self, test_systems_path: Path) -> None:
        """put_dir() creates remote directories and uploads files for a tree.

        The ``system1`` fixture tree is expected to yield three put() and
        three mkdir() calls.
        """
        self.sftp_client.put_dir(test_systems_path / "system1", "/tmp/dest")

        assert self.mock_put.call_count == 3
        assert self.mock_mkdir.call_count == 3

    def test_mkdir(self) -> None:
        """_mkdir() suppresses mkdir errors only when ignore_existing is set."""
        self.mock_mkdir.side_effect = OSError("Cannot create directory")

        # With ignore_existing=True the error must not propagate...
        self.sftp_client._mkdir("new_directory", ignore_existing=True)

        # ...while with ignore_existing=False it does.
        with pytest.raises(OSError, match="Cannot create directory"):
            self.sftp_client._mkdir("new_directory", ignore_existing=False)