aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/test_backend_protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/aiet/test_backend_protocol.py')
-rw-r--r--tests/aiet/test_backend_protocol.py231
1 files changed, 231 insertions, 0 deletions
diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py
new file mode 100644
index 0000000..2103238
--- /dev/null
+++ b/tests/aiet/test_backend_protocol.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access
+"""Tests for the protocol backend module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+
+import paramiko
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.protocol import CustomSFTPClient
+from aiet.backend.protocol import LocalProtocol
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SSHProtocol
+
+
+class TestProtocolFactory:
+ """Test ProtocolFactory class."""
+
+ @pytest.mark.parametrize(
+ "config, expected_class, exception",
+ [
+ (
+ {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ },
+ SSHProtocol,
+ does_not_raise(),
+ ),
+ ({"protocol": "local"}, LocalProtocol, does_not_raise()),
+ (
+ {"protocol": "something"},
+ None,
+ pytest.raises(Exception, match="Protocol not supported"),
+ ),
+ (None, None, pytest.raises(Exception, match="No protocol config provided")),
+ ],
+ )
+ def test_get_protocol(
+ self, config: Any, expected_class: type, exception: Any
+ ) -> None:
+ """Test get_protocol method."""
+ factory = ProtocolFactory()
+ with exception:
+ protocol = factory.get_protocol(config)
+ assert isinstance(protocol, expected_class)
+
+
+class TestLocalProtocol:
+ """Test local protocol."""
+
+ def test_local_protocol_run_command(self) -> None:
+ """Test local protocol run command."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("/tmp"))
+ ret, stdout, stderr = protocol.run("pwd")
+ assert ret == 0
+ assert stdout.decode("utf-8").strip() == "/tmp"
+ assert stderr.decode("utf-8") == ""
+
+ def test_local_protocol_run_wrong_cwd(self) -> None:
+ """Execution should fail if wrong working directory provided."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("unknown_directory"))
+ with pytest.raises(
+ ConfigurationException, match="Wrong working directory unknown_directory"
+ ):
+ protocol.run("pwd")
+
+
+class TestSSHProtocol:
+ """Test SSH protocol."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up protocol mocks."""
+ self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient)
+
+ self.mock_ssh_channel = (
+ self.mock_ssh_client.get_transport.return_value.open_session.return_value
+ )
+ self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel)
+ self.mock_ssh_channel.exit_status_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_exit_status.return_value = True
+ self.mock_ssh_channel.recv_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True]
+
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(return_value=self.mock_ssh_client),
+ )
+
+ self.mock_sftp_client = MagicMock(spec=CustomSFTPClient)
+ monkeypatch.setattr(
+ "aiet.backend.protocol.CustomSFTPClient.from_transport",
+ MagicMock(return_value=self.mock_sftp_client),
+ )
+
+ ssh_config = {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ }
+ self.protocol = SSHProtocol(ssh_config)
+
+ def test_unable_create_ssh_client(self, monkeypatch: Any) -> None:
+ """Test that command should fail if unable to create ssh client instance."""
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(side_effect=OSError("Error!")),
+ )
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command(self) -> None:
+ """Test that command run via ssh successfully."""
+ self.protocol.run("command_example")
+ self.mock_ssh_channel.exec_command.assert_called_once()
+
+ def test_ssh_protocol_run_command_connect_failed(self) -> None:
+ """Test that if connection is not possible then correct exception is raised."""
+ self.mock_ssh_client.connect.side_effect = OSError("Unable to connect")
+ self.mock_ssh_client.close.side_effect = Exception("Error!")
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command_bad_transport(self) -> None:
+ """Test that command should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_deploy_command_file(
+ self, test_applications_path: Path
+ ) -> None:
+ """Test that files could be deployed over ssh."""
+ file_for_deploy = test_applications_path / "readme.txt"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(file_for_deploy, dest)
+ self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest)
+
+ def test_ssh_protocol_deploy_command_unknown_file(self) -> None:
+ """Test that deploy will fail if file does not exist."""
+ with pytest.raises(Exception, match="Deploy error: file type not supported"):
+ self.protocol.deploy(Path("unknown_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_bad_transport(self) -> None:
+ """Test that deploy should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.deploy(Path("some_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_directory(
+ self, test_resources_path: Path
+ ) -> None:
+ """Test that directory could be deployed over ssh."""
+ directory_for_deploy = test_resources_path / "scripts"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(directory_for_deploy, dest)
+ self.mock_sftp_client.put_dir.assert_called_once_with(
+ directory_for_deploy, dest
+ )
+
+ @pytest.mark.parametrize("establish_connection", (True, False))
+ def test_ssh_protocol_close(self, establish_connection: bool) -> None:
+ """Test protocol close operation."""
+ if establish_connection:
+ self.protocol.establish_connection()
+ self.protocol.close()
+
+ call_count = 1 if establish_connection else 0
+ assert self.mock_ssh_channel.exec_command.call_count == call_count
+
+ def test_connection_details(self) -> None:
+ """Test getting connection details."""
+ assert self.protocol.connection_details() == ("hostname", 22)
+
+
+class TestCustomSFTPClient:
+ """Test CustomSFTPClient class."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up mocks for CustomSFTPClient instance."""
+ self.mock_mkdir = MagicMock()
+ self.mock_put = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.__init__",
+ MagicMock(return_value=None),
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put
+ )
+
+ self.sftp_client = CustomSFTPClient(MagicMock())
+
+ def test_put_dir(self, test_systems_path: Path) -> None:
+ """Test deploying directory to remote host."""
+ directory_for_deploy = test_systems_path / "system1"
+
+ self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest")
+ assert self.mock_put.call_count == 3
+ assert self.mock_mkdir.call_count == 3
+
+ def test_mkdir(self) -> None:
+ """Test creating directory on remote host."""
+ self.mock_mkdir.side_effect = IOError("Cannot create directory")
+
+ self.sftp_client._mkdir("new_directory", ignore_existing=True)
+
+ with pytest.raises(IOError, match="Cannot create directory"):
+ self.sftp_client._mkdir("new_directory", ignore_existing=False)