aboutsummaryrefslogtreecommitdiff
path: root/tests/test_utils_py_manager.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils_py_manager.py')
-rw-r--r--tests/test_utils_py_manager.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tests/test_utils_py_manager.py b/tests/test_utils_py_manager.py
new file mode 100644
index 0000000..e41680d
--- /dev/null
+++ b/tests/test_utils_py_manager.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for python package manager."""
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.py_manager import get_package_manager
+from mlia.utils.py_manager import PyPackageManager
+
+
+def test_get_package_manager() -> None:
+ """Test function get_package_manager."""
+ manager = get_package_manager()
+ assert isinstance(manager, PyPackageManager)
+
+
+@pytest.fixture(name="mock_check_call")
+def mock_check_call_fixture(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock check_call function."""
+ mock_check_call = MagicMock()
+ monkeypatch.setattr("mlia.utils.py_manager.check_call", mock_check_call)
+
+ return mock_check_call
+
+
+def test_py_package_manager_metadata() -> None:
+ """Test getting package status."""
+ manager = PyPackageManager()
+ assert manager.package_installed("pytest")
+ assert manager.packages_installed(["pytest", "mlia"])
+
+
+def test_py_package_manager_install(mock_check_call: MagicMock) -> None:
+ """Test package installation."""
+ manager = PyPackageManager()
+ with pytest.raises(ValueError, match="No package names provided"):
+ manager.install([])
+
+ manager.install(["mlia", "pytest"])
+ mock_check_call.assert_called_once_with(
+ [
+ sys.executable,
+ "-m",
+ "pip",
+ "--disable-pip-version-check",
+ "install",
+ "mlia",
+ "pytest",
+ ]
+ )
+
+
+def test_py_package_manager_uninstall(mock_check_call: MagicMock) -> None:
+ """Test package removal."""
+ manager = PyPackageManager()
+ with pytest.raises(ValueError, match="No package names provided"):
+ manager.uninstall([])
+
+ manager.uninstall(["mlia", "pytest"])
+ mock_check_call.assert_called_once_with(
+ [
+ sys.executable,
+ "-m",
+ "pip",
+ "--disable-pip-version-check",
+ "uninstall",
+ "--yes",
+ "mlia",
+ "pytest",
+ ]
+ )