aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_backend_protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mlia/test_backend_protocol.py')
-rw-r--r--tests/mlia/test_backend_protocol.py231
1 files changed, 0 insertions, 231 deletions
diff --git a/tests/mlia/test_backend_protocol.py b/tests/mlia/test_backend_protocol.py
deleted file mode 100644
index 35e9986..0000000
--- a/tests/mlia/test_backend_protocol.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access
-"""Tests for the protocol backend module."""
-from contextlib import ExitStack as does_not_raise
-from pathlib import Path
-from typing import Any
-from unittest.mock import MagicMock
-
-import paramiko
-import pytest
-
-from mlia.backend.common import ConfigurationException
-from mlia.backend.config import LocalProtocolConfig
-from mlia.backend.protocol import CustomSFTPClient
-from mlia.backend.protocol import LocalProtocol
-from mlia.backend.protocol import ProtocolFactory
-from mlia.backend.protocol import SSHProtocol
-
-
-class TestProtocolFactory:
- """Test ProtocolFactory class."""
-
- @pytest.mark.parametrize(
- "config, expected_class, exception",
- [
- (
- {
- "protocol": "ssh",
- "username": "user",
- "password": "pass",
- "hostname": "hostname",
- "port": "22",
- },
- SSHProtocol,
- does_not_raise(),
- ),
- ({"protocol": "local"}, LocalProtocol, does_not_raise()),
- (
- {"protocol": "something"},
- None,
- pytest.raises(Exception, match="Protocol not supported"),
- ),
- (None, None, pytest.raises(Exception, match="No protocol config provided")),
- ],
- )
- def test_get_protocol(
- self, config: Any, expected_class: type, exception: Any
- ) -> None:
- """Test get_protocol method."""
- factory = ProtocolFactory()
- with exception:
- protocol = factory.get_protocol(config)
- assert isinstance(protocol, expected_class)
-
-
-class TestLocalProtocol:
- """Test local protocol."""
-
- def test_local_protocol_run_command(self) -> None:
- """Test local protocol run command."""
- config = LocalProtocolConfig(protocol="local")
- protocol = LocalProtocol(config, cwd=Path("/tmp"))
- ret, stdout, stderr = protocol.run("pwd")
- assert ret == 0
- assert stdout.decode("utf-8").strip() == "/tmp"
- assert stderr.decode("utf-8") == ""
-
- def test_local_protocol_run_wrong_cwd(self) -> None:
- """Execution should fail if wrong working directory provided."""
- config = LocalProtocolConfig(protocol="local")
- protocol = LocalProtocol(config, cwd=Path("unknown_directory"))
- with pytest.raises(
- ConfigurationException, match="Wrong working directory unknown_directory"
- ):
- protocol.run("pwd")
-
-
-class TestSSHProtocol:
- """Test SSH protocol."""
-
- @pytest.fixture(autouse=True)
- def setup_method(self, monkeypatch: Any) -> None:
- """Set up protocol mocks."""
- self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient)
-
- self.mock_ssh_channel = (
- self.mock_ssh_client.get_transport.return_value.open_session.return_value
- )
- self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel)
- self.mock_ssh_channel.exit_status_ready.side_effect = [False, True]
- self.mock_ssh_channel.recv_exit_status.return_value = True
- self.mock_ssh_channel.recv_ready.side_effect = [False, True]
- self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True]
-
- monkeypatch.setattr(
- "mlia.backend.protocol.paramiko.client.SSHClient",
- MagicMock(return_value=self.mock_ssh_client),
- )
-
- self.mock_sftp_client = MagicMock(spec=CustomSFTPClient)
- monkeypatch.setattr(
- "mlia.backend.protocol.CustomSFTPClient.from_transport",
- MagicMock(return_value=self.mock_sftp_client),
- )
-
- ssh_config = {
- "protocol": "ssh",
- "username": "user",
- "password": "pass",
- "hostname": "hostname",
- "port": "22",
- }
- self.protocol = SSHProtocol(ssh_config)
-
- def test_unable_create_ssh_client(self, monkeypatch: Any) -> None:
- """Test that command should fail if unable to create ssh client instance."""
- monkeypatch.setattr(
- "mlia.backend.protocol.paramiko.client.SSHClient",
- MagicMock(side_effect=OSError("Error!")),
- )
-
- with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
- self.protocol.run("command_example", retry=False)
-
- def test_ssh_protocol_run_command(self) -> None:
- """Test that command run via ssh successfully."""
- self.protocol.run("command_example")
- self.mock_ssh_channel.exec_command.assert_called_once()
-
- def test_ssh_protocol_run_command_connect_failed(self) -> None:
- """Test that if connection is not possible then correct exception is raised."""
- self.mock_ssh_client.connect.side_effect = OSError("Unable to connect")
- self.mock_ssh_client.close.side_effect = Exception("Error!")
-
- with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
- self.protocol.run("command_example", retry=False)
-
- def test_ssh_protocol_run_command_bad_transport(self) -> None:
- """Test that command should fail if unable to get transport."""
- self.mock_ssh_client.get_transport.return_value = None
-
- with pytest.raises(Exception, match="Unable to get transport"):
- self.protocol.run("command_example", retry=False)
-
- def test_ssh_protocol_deploy_command_file(
- self, test_applications_path: Path
- ) -> None:
- """Test that files could be deployed over ssh."""
- file_for_deploy = test_applications_path / "readme.txt"
- dest = "/tmp/dest"
-
- self.protocol.deploy(file_for_deploy, dest)
- self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest)
-
- def test_ssh_protocol_deploy_command_unknown_file(self) -> None:
- """Test that deploy will fail if file does not exist."""
- with pytest.raises(Exception, match="Deploy error: file type not supported"):
- self.protocol.deploy(Path("unknown_file"), "/tmp/dest")
-
- def test_ssh_protocol_deploy_command_bad_transport(self) -> None:
- """Test that deploy should fail if unable to get transport."""
- self.mock_ssh_client.get_transport.return_value = None
-
- with pytest.raises(Exception, match="Unable to get transport"):
- self.protocol.deploy(Path("some_file"), "/tmp/dest")
-
- def test_ssh_protocol_deploy_command_directory(
- self, test_resources_path: Path
- ) -> None:
- """Test that directory could be deployed over ssh."""
- directory_for_deploy = test_resources_path / "scripts"
- dest = "/tmp/dest"
-
- self.protocol.deploy(directory_for_deploy, dest)
- self.mock_sftp_client.put_dir.assert_called_once_with(
- directory_for_deploy, dest
- )
-
- @pytest.mark.parametrize("establish_connection", (True, False))
- def test_ssh_protocol_close(self, establish_connection: bool) -> None:
- """Test protocol close operation."""
- if establish_connection:
- self.protocol.establish_connection()
- self.protocol.close()
-
- call_count = 1 if establish_connection else 0
- assert self.mock_ssh_channel.exec_command.call_count == call_count
-
- def test_connection_details(self) -> None:
- """Test getting connection details."""
- assert self.protocol.connection_details() == ("hostname", 22)
-
-
-class TestCustomSFTPClient:
- """Test CustomSFTPClient class."""
-
- @pytest.fixture(autouse=True)
- def setup_method(self, monkeypatch: Any) -> None:
- """Set up mocks for CustomSFTPClient instance."""
- self.mock_mkdir = MagicMock()
- self.mock_put = MagicMock()
- monkeypatch.setattr(
- "mlia.backend.protocol.paramiko.SFTPClient.__init__",
- MagicMock(return_value=None),
- )
- monkeypatch.setattr(
- "mlia.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir
- )
- monkeypatch.setattr(
- "mlia.backend.protocol.paramiko.SFTPClient.put", self.mock_put
- )
-
- self.sftp_client = CustomSFTPClient(MagicMock())
-
- def test_put_dir(self, test_systems_path: Path) -> None:
- """Test deploying directory to remote host."""
- directory_for_deploy = test_systems_path / "system1"
-
- self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest")
- assert self.mock_put.call_count == 3
- assert self.mock_mkdir.call_count == 3
-
- def test_mkdir(self) -> None:
- """Test creating directory on remote host."""
- self.mock_mkdir.side_effect = IOError("Cannot create directory")
-
- self.sftp_client._mkdir("new_directory", ignore_existing=True)
-
- with pytest.raises(IOError, match="Cannot create directory"):
- self.sftp_client._mkdir("new_directory", ignore_existing=False)